Compare commits

..

1 Commits

Author SHA1 Message Date
Timmy
e334c5256c feat: marathon session limits — cap, checkpoint, rotate (#326)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 54s
- Add max_messages (default 200) to SessionResetPolicy
- Track message_count in SessionEntry (persisted to sessions.json)
- Add 'message_limit' reset reason to _should_reset
- Auto-checkpoint filesystem before session rotation
- Inject near-limit warnings (85%/100%) into agent ephemeral prompt
- Auto-rotate sessions when message cap is hit
- Add get_message_limit_info() and reset_message_count() APIs
- 24 new tests covering all limit behaviors

Evidence: 170 sessions exceed 100 msgs, longest 1,643 msgs (40h).
Marathon sessions show 45-84% error rates from tool fixation.
Cap + checkpoint + restart breaks the death spiral.
2026-04-13 18:51:23 -04:00
7 changed files with 316 additions and 990 deletions

View File

@@ -1,5 +0,0 @@
"""Hermes daemon services — long-running background processes."""
from daemon.confirmation_server import ConfirmationServer
__all__ = ["ConfirmationServer"]

View File

@@ -1,664 +0,0 @@
"""Human Confirmation Daemon — route high-risk actions through human review.
HTTP server on port 6000 that intercepts dangerous operations and holds them
until a human approves or denies. Integrates with the existing approval
system (tools/approval.py) and notifies humans via Telegram/Discord/CLI.
Endpoints:
POST /confirm — submit a high-risk action for review
POST /confirm/{id}/approve — approve a pending confirmation
POST /confirm/{id}/deny — deny a pending confirmation
GET /confirm/{id} — check status of a confirmation
GET /audit — recent audit log entries
GET /health — liveness probe
Every decision is logged to SQLite for audit.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sqlite3
import threading
import time
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
try:
from aiohttp import web
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
web = None # type: ignore[assignment]
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Address and port the daemon binds to when none are supplied.
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 6000

# Actions that always require human confirmation (not bypassable).
HIGH_RISK_ACTIONS = {
    "delete_data",
    "deploy_production",
    "exec_remote",
    "migrate_database",
    "modify_permissions",
    "publish_package",
    "rotate_keys",
    "shutdown_service",
    "transfer_funds",
    "wipe_database",
}

# Default rate limits: at most DEFAULT_RATE_LIMIT confirmations per action
# type within a window of RATE_LIMIT_WINDOW seconds (one hour).
DEFAULT_RATE_LIMIT = 10
RATE_LIMIT_WINDOW = 60 * 60
# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------
@dataclass
class ConfirmationRequest:
    """One confirmation record, either still pending or already resolved."""

    id: str
    action: str
    description: str
    details: Dict[str, Any] = field(default_factory=dict)
    requester: str = ""  # agent or user who requested
    session_key: str = ""
    status: str = "pending"  # pending | approved | denied | expired
    resolved_by: str = ""
    resolved_at: Optional[float] = None
    created_at: float = field(default_factory=time.time)
    timeout_seconds: int = 300  # 5 min default

    def to_dict(self) -> dict:
        """Return the fields as a dict, plus ISO-8601 renderings of both timestamps.

        ``resolved_at_iso`` is ``None`` while ``resolved_at`` is unset (or falsy).
        """
        payload = asdict(self)
        resolved = payload["resolved_at"]
        payload["created_at_iso"] = _ts_to_iso(payload["created_at"])
        payload["resolved_at_iso"] = _ts_to_iso(resolved) if resolved else None
        return payload
def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
# ---------------------------------------------------------------------------
# Audit log (SQLite)
# ---------------------------------------------------------------------------
class AuditLog:
    """SQLite-backed audit log for all confirmation decisions."""

    def __init__(self, db_path: Optional[str] = None):
        """Open (creating if needed) the audit database.

        Args:
            db_path: SQLite file path. When ``None``, the database is
                ``confirmation_audit.db`` under ``$HERMES_HOME`` (falling back
                to ``~/.hermes``), and that directory is created if missing.
        """
        if db_path is None:
            home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
            home.mkdir(parents=True, exist_ok=True)
            db_path = str(home / "confirmation_audit.db")
        # check_same_thread=False lets the connection be used from a thread
        # other than the one that created it.
        self._conn = sqlite3.connect(db_path, check_same_thread=False)
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("""
            CREATE TABLE IF NOT EXISTS audit_log (
                id TEXT PRIMARY KEY,
                action TEXT NOT NULL,
                description TEXT NOT NULL,
                details TEXT NOT NULL DEFAULT '{}',
                requester TEXT NOT NULL DEFAULT '',
                session_key TEXT NOT NULL DEFAULT '',
                status TEXT NOT NULL,
                resolved_by TEXT NOT NULL DEFAULT '',
                created_at REAL NOT NULL,
                resolved_at REAL,
                resolved_at_iso TEXT,
                created_at_iso TEXT
            )
        """)
        self._conn.commit()

    def log(self, req: ConfirmationRequest) -> None:
        """Insert or overwrite the row for ``req`` (keyed by its id).

        ``details`` is stored as a JSON string; all other columns come
        straight from ``req.to_dict()``.
        """
        d = req.to_dict()
        self._conn.execute(
            """INSERT OR REPLACE INTO audit_log
               (id, action, description, details, requester, session_key,
                status, resolved_by, created_at, resolved_at,
                resolved_at_iso, created_at_iso)
               VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
            (
                d["id"], d["action"], d["description"],
                json.dumps(d["details"]), d["requester"], d["session_key"],
                d["status"], d["resolved_by"],
                d["created_at"], d["resolved_at"],
                d["resolved_at_iso"], d["created_at_iso"],
            ),
        )
        self._conn.commit()

    def recent(self, limit: int = 50) -> List[dict]:
        """Return up to ``limit`` rows, newest ``created_at`` first, as dicts."""
        cursor = self._conn.execute(
            "SELECT * FROM audit_log ORDER BY created_at DESC LIMIT ?", (limit,)
        )
        # Column names live on the cursor; sqlite3.Connection has no
        # ``description`` attribute, so it must be read here.
        cols = [col[0] for col in cursor.description]
        return [dict(zip(cols, row)) for row in cursor.fetchall()]

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()
# ---------------------------------------------------------------------------
# Rate limiter
# ---------------------------------------------------------------------------
class RateLimiter:
    """Sliding-window limiter that tracks accepted calls per action type."""

    def __init__(self, max_per_window: int = DEFAULT_RATE_LIMIT,
                 window: int = RATE_LIMIT_WINDOW):
        """Allow at most ``max_per_window`` accepted calls per ``window`` seconds."""
        self._lock = threading.Lock()
        self._timestamps: Dict[str, List[float]] = {}  # action -> [ts, ...]
        self._window = window
        self._max = max_per_window

    def check(self, action: str) -> bool:
        """Record a call for ``action`` and return True if it fits the limit.

        A rejected call (False) is not recorded. Expired timestamps for the
        action are pruned as a side effect.
        """
        current = time.time()
        with self._lock:
            live = [
                stamp
                for stamp in self._timestamps.get(action, [])
                if current - stamp < self._window
            ]
            self._timestamps[action] = live
            if len(live) < self._max:
                live.append(current)
                return True
            return False

    def remaining(self, action: str) -> int:
        """Return how many more calls for ``action`` fit in the current window."""
        current = time.time()
        with self._lock:
            in_window = sum(
                1
                for stamp in self._timestamps.get(action, [])
                if current - stamp < self._window
            )
            return max(0, self._max - in_window)
# ---------------------------------------------------------------------------
# Notification dispatcher
# ---------------------------------------------------------------------------
async def _notify_human(request: ConfirmationRequest) -> None:
    """Send a notification about a pending confirmation to humans.

    Tries Telegram first, then Discord, and stops at the first channel whose
    send result carries no ``"error"`` key. If neither succeeds, a warning is
    logged instead. Failures of individual channels are logged at debug level
    and never raised to the caller.

    Args:
        request: The pending confirmation to announce.
    """
    msg = (
        f"\U0001f514 Confirmation Required\n"
        f"Action: {request.action}\n"
        f"Description: {request.description}\n"
        f"Requester: {request.requester}\n"
        f"ID: {request.id}\n\n"
        f"Approve: POST /confirm/{request.id}/approve\n"
        f"Deny: POST /confirm/{request.id}/deny"
    )
    # Channels in priority order: (send target, human-readable label).
    channels = (("telegram", "Telegram"), ("discord", "Discord"))
    for target, label in channels:
        try:
            # Imported lazily so a missing messaging tool only disables
            # notifications instead of breaking the daemon at import time.
            from tools.send_message_tool import _handle_send
            # NOTE(review): _handle_send is called synchronously here; if it
            # blocks on network I/O it will stall the event loop -- confirm.
            result = _handle_send({
                "target": target,
                "message": msg,
            })
            result_dict = json.loads(result) if isinstance(result, str) else result
            if "error" not in result_dict:
                logger.info("Confirmation %s: notified via %s", request.id, label)
                return
        except Exception as e:
            logger.debug("%s notify failed: %s", label, e)
    logger.warning(
        "Confirmation %s: no messaging channel available. "
        "Action '%s' requires human review -- check /confirm/%s",
        request.id, request.action, request.id,
    )
# ---------------------------------------------------------------------------
# Whitelist manager
# ---------------------------------------------------------------------------
class Whitelist:
    """Registry of actions allowed to skip confirmation, globally or per session."""

    def __init__(self):
        self._lock = threading.Lock()
        self._global_allowed: set = set()
        self._allowed: Dict[str, set] = {}  # session_key -> {action, ...}

    def is_whitelisted(self, action: str, session_key: str = "") -> bool:
        """Return True if ``action`` is allowed globally or for ``session_key``."""
        with self._lock:
            if action in self._global_allowed:
                return True
            if not session_key:
                return False
            return action in self._allowed.get(session_key, set())

    def add(self, action: str, session_key: str = "") -> None:
        """Allow ``action`` for ``session_key``, or globally when no key is given."""
        with self._lock:
            if not session_key:
                self._global_allowed.add(action)
                return
            self._allowed.setdefault(session_key, set()).add(action)

    def remove(self, action: str, session_key: str = "") -> None:
        """Revoke ``action`` from the session's set, or from the global set.

        Removing something that was never whitelisted is a no-op.
        """
        with self._lock:
            if not session_key:
                self._global_allowed.discard(action)
                return
            scoped = self._allowed.get(session_key)
            if scoped is not None:
                scoped.discard(action)
# ---------------------------------------------------------------------------
# Confirmation Server
# ---------------------------------------------------------------------------
class ConfirmationServer:
    """HTTP server for human confirmation of high-risk actions.

    Usage:
        server = ConfirmationServer()
        await server.start()        # blocks
        # or:
        server.start_background()   # non-blocking
        ...
        await server.stop()

    Submitted requests are kept in memory in ``_pending`` (resolved ones stay
    there so their status remains queryable), and every state change of a
    stored request is written to the audit log.
    """

    def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT,
                 db_path: Optional[str] = None,
                 rate_limit: int = DEFAULT_RATE_LIMIT):
        """Create the server without binding any socket yet.

        Args:
            host: Bind address.
            port: Bind port.
            db_path: Audit database path, passed through to ``AuditLog``.
            rate_limit: Max accepted submissions per action type per window.

        Raises:
            RuntimeError: If aiohttp is not installed.
        """
        if not AIOHTTP_AVAILABLE:
            raise RuntimeError(
                "aiohttp is required for the confirmation daemon. "
                "Install with: pip install aiohttp"
            )
        self._host = host
        self._port = port
        self._audit = AuditLog(db_path)
        self._rate_limiter = RateLimiter(max_per_window=rate_limit)
        self._whitelist = Whitelist()
        self._pending: Dict[str, ConfirmationRequest] = {}
        self._lock = threading.Lock()
        self._app: Optional[web.Application] = None
        self._runner: Optional[web.AppRunner] = None
        self._bg_thread: Optional[threading.Thread] = None
        # Strong references to in-flight notification tasks. The event loop
        # only holds weak references to tasks, so a task nobody references
        # can be garbage-collected before it completes.
        self._notify_tasks: set = set()

    # --- Lifecycle ---

    def _build_app(self) -> web.Application:
        """Build the aiohttp application and register all routes."""
        app = web.Application(client_max_size=1_048_576)  # 1 MB
        app["server"] = self
        app.router.add_post("/confirm", self._handle_submit)
        app.router.add_post("/confirm/{req_id}/approve", self._handle_approve)
        app.router.add_post("/confirm/{req_id}/deny", self._handle_deny)
        app.router.add_get("/confirm/{req_id}", self._handle_status)
        app.router.add_get("/audit", self._handle_audit)
        app.router.add_get("/health", self._handle_health)
        return app

    async def start(self) -> None:
        """Start the server and block until stopped."""
        self._app = self._build_app()
        self._runner = web.AppRunner(self._app)
        await self._runner.setup()
        site = web.TCPSite(self._runner, self._host, self._port)
        await site.start()
        logger.info(
            "Confirmation daemon listening on http://%s:%s",
            self._host, self._port,
        )
        # Run until cancelled
        try:
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            pass
        finally:
            await self.stop()

    def start_background(self) -> None:
        """Start the server in a background thread (non-blocking)."""
        def _run():
            asyncio.run(self.start())
        self._bg_thread = threading.Thread(target=_run, daemon=True)
        self._bg_thread.start()

    async def stop(self) -> None:
        """Gracefully stop the server and close the audit log."""
        if self._runner:
            await self._runner.cleanup()
            self._runner = None
        self._audit.close()
        logger.info("Confirmation daemon stopped")

    # --- Internal helpers ---

    @staticmethod
    def _clean_text(value: Any, default: str = "") -> str:
        """Return ``value`` stripped if it is a non-empty string, else ``default``."""
        if isinstance(value, str) and value:
            return value.strip()
        return default

    @staticmethod
    async def _read_json_object(request: web.Request) -> Dict[str, Any]:
        """Parse the request body as a JSON object; return ``{}`` on any failure."""
        try:
            body = await request.json()
        except Exception:
            return {}
        return body if isinstance(body, dict) else {}

    @staticmethod
    def _not_found(req_id: str) -> web.Response:
        """Build the 404 response shared by the per-request endpoints."""
        return web.json_response(
            {"error": f"Confirmation '{req_id}' not found"}, status=404
        )

    def _get_request(self, req_id: str) -> Optional[ConfirmationRequest]:
        """Look up a stored request by id, or return ``None``."""
        with self._lock:
            return self._pending.get(req_id)

    def _expire_if_due_locked(self, req: ConfirmationRequest, now: float) -> bool:
        """Mark ``req`` expired if it is pending and past its timeout.

        The caller must hold ``self._lock``. Returns True if it was expired.
        """
        if req.status != "pending" or (now - req.created_at) <= req.timeout_seconds:
            return False
        req.status = "expired"
        req.resolved_at = now
        req.resolved_by = "system:timeout"
        self._audit.log(req)
        return True

    def _expire_old_requests(self) -> int:
        """Mark expired requests. Returns count expired."""
        now = time.time()
        with self._lock:
            return sum(
                1 for req in list(self._pending.values())
                if self._expire_if_due_locked(req, now)
            )

    def _resolve(self, req_id: str, new_status: str, resolved_by: str):
        """Atomically move a pending request to ``new_status``.

        The lookup, timeout check, status check and mutation all happen under
        the lock with no ``await`` in between, so two concurrent approve/deny
        calls cannot both succeed, and a request that has outlived its timeout
        is expired instead of resolved.

        Returns:
            ``(request, None)`` on success, or ``(None, error_response)``
            with a 404 (unknown id) or 409 (no longer pending) response.
        """
        now = time.time()
        with self._lock:
            req = self._pending.get(req_id)
            if req is None:
                return None, self._not_found(req_id)
            self._expire_if_due_locked(req, now)
            if req.status != "pending":
                return None, web.json_response({
                    "error": f"Confirmation '{req_id}' already resolved",
                    "status": req.status,
                }, status=409)
            req.status = new_status
            req.resolved_by = resolved_by
            req.resolved_at = now
            self._audit.log(req)
        return req, None

    # --- HTTP handlers ---

    async def _handle_submit(self, request: web.Request) -> web.Response:
        """POST /confirm -- submit a new confirmation request.

        Responds 400 for a malformed body or missing ``action``/``description``,
        200 with ``auto_approved`` for whitelisted non-high-risk actions,
        429 when the action's rate limit is exhausted, and 202 otherwise.
        """
        try:
            body = await request.json()
        except Exception:
            return web.json_response(
                {"error": "Invalid JSON body"}, status=400
            )
        # Valid JSON that is not an object (list, string, number, null) has
        # no fields to read; reject it instead of failing on ``.get``.
        if not isinstance(body, dict):
            return web.json_response(
                {"error": "Invalid JSON body"}, status=400
            )
        action = self._clean_text(body.get("action"))
        description = self._clean_text(body.get("description"))
        details = body.get("details") or {}
        requester = self._clean_text(body.get("requester"), "agent")
        session_key = self._clean_text(body.get("session_key"))
        timeout = body.get("timeout_seconds", 300)
        if not action:
            return web.json_response(
                {"error": "Field 'action' is required"}, status=400
            )
        if not description:
            return web.json_response(
                {"error": "Field 'description' is required"}, status=400
            )
        # Whitelist check. High-risk actions always go to a human, so a
        # whitelist entry must never short-circuit them.
        if action not in HIGH_RISK_ACTIONS and self._whitelist.is_whitelisted(action, session_key):
            auto_id = str(uuid.uuid4())[:8]
            return web.json_response({
                "id": auto_id,
                "action": action,
                "status": "auto_approved",
                "message": f"Action '{action}' is whitelisted for this session",
            })
        # Rate limit check
        if not self._rate_limiter.check(action):
            remaining = self._rate_limiter.remaining(action)
            return web.json_response({
                "error": f"Rate limit exceeded for action '{action}'",
                "remaining": remaining,
                "window_seconds": RATE_LIMIT_WINDOW,
            }, status=429)
        # Enforce timeout bounds (30 s .. 1 h); fall back to 5 min on junk.
        try:
            timeout = max(30, min(int(timeout), 3600))
        except (ValueError, TypeError):
            timeout = 300
        # Create request
        req = ConfirmationRequest(
            id=str(uuid.uuid4())[:12],
            action=action,
            description=description,
            details=details,
            requester=requester,
            session_key=session_key,
            timeout_seconds=timeout,
        )
        with self._lock:
            self._pending[req.id] = req
        # Audit log
        self._audit.log(req)
        # Notify humans (fire-and-forget); keep a reference until done so the
        # task cannot be garbage-collected mid-flight.
        task = asyncio.create_task(_notify_human(req))
        self._notify_tasks.add(task)
        task.add_done_callback(self._notify_tasks.discard)
        logger.info(
            "Confirmation %s submitted: action=%s requester=%s",
            req.id, action, requester,
        )
        return web.json_response({
            "id": req.id,
            "action": req.action,
            "status": req.status,
            "timeout_seconds": req.timeout_seconds,
            "message": "Confirmation pending. Awaiting human review.",
            "approve_url": f"/confirm/{req.id}/approve",
            "deny_url": f"/confirm/{req.id}/deny",
        }, status=202)

    async def _handle_approve(self, request: web.Request) -> web.Response:
        """POST /confirm/{id}/approve -- approve a pending confirmation.

        The optional JSON body may carry ``approver``; it defaults to "api".
        """
        req_id = request.match_info["req_id"]
        # Read the body before touching state so that no ``await`` sits
        # between the pending check and the status change.
        body = await self._read_json_object(request)
        approver = self._clean_text(body.get("approver"), "api")
        req, error = self._resolve(req_id, "approved", approver)
        if error is not None:
            return error
        logger.info(
            "Confirmation %s APPROVED by %s (action=%s)",
            req_id, approver, req.action,
        )
        return web.json_response({
            "id": req.id,
            "action": req.action,
            "status": "approved",
            "resolved_by": approver,
            "resolved_at": req.resolved_at,
        })

    async def _handle_deny(self, request: web.Request) -> web.Response:
        """POST /confirm/{id}/deny -- deny a pending confirmation.

        The optional JSON body may carry ``denier`` (default "api") and
        ``reason``; a non-empty reason is echoed back in the response.
        """
        req_id = request.match_info["req_id"]
        body = await self._read_json_object(request)
        denier = self._clean_text(body.get("denier"), "api")
        reason = self._clean_text(body.get("reason"))
        req, error = self._resolve(req_id, "denied", denier)
        if error is not None:
            return error
        logger.info(
            "Confirmation %s DENIED by %s (action=%s, reason=%s)",
            req_id, denier, req.action, reason,
        )
        resp = {
            "id": req.id,
            "action": req.action,
            "status": "denied",
            "resolved_by": denier,
            "resolved_at": req.resolved_at,
        }
        if reason:
            resp["reason"] = reason
        return web.json_response(resp)

    async def _handle_status(self, request: web.Request) -> web.Response:
        """GET /confirm/{id} -- check status of a confirmation.

        A pending request that has outlived its timeout is marked expired
        (and audited) before being reported.
        """
        req_id = request.match_info["req_id"]
        with self._lock:
            req = self._pending.get(req_id)
            if req is not None:
                self._expire_if_due_locked(req, time.time())
                payload = req.to_dict()
        if req is None:
            return self._not_found(req_id)
        return web.json_response(payload)

    async def _handle_audit(self, request: web.Request) -> web.Response:
        """GET /audit -- recent audit log entries (``?limit=`` clamped to 1..500)."""
        try:
            limit = int(request.query.get("limit", "50"))
            limit = max(1, min(limit, 500))
        except (ValueError, TypeError):
            limit = 50
        entries = self._audit.recent(limit)
        return web.json_response({
            "count": len(entries),
            "entries": entries,
        })

    async def _handle_health(self, request: web.Request) -> web.Response:
        """GET /health -- liveness probe."""
        with self._lock:
            # ``_pending`` also retains resolved requests, so count by status.
            pending_count = sum(
                1 for req in self._pending.values() if req.status == "pending"
            )
        return web.json_response({
            "status": "ok",
            "pending_count": pending_count,
            "timestamp": time.time(),
        })
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
    """Parse CLI/environment options and run the confirmation daemon until stopped.

    Environment variables CONFIRMATION_HOST, CONFIRMATION_PORT and
    CONFIRMATION_RATE_LIMIT provide the defaults for the matching flags.
    Exits with status 1 if aiohttp is not installed.
    """
    import argparse

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    cli = argparse.ArgumentParser(
        description="Hermes Human Confirmation Daemon"
    )
    env_host = os.getenv("CONFIRMATION_HOST", DEFAULT_HOST)
    cli.add_argument(
        "--host",
        default=env_host,
        help="Bind address (default: 127.0.0.1)",
    )
    env_port = int(os.getenv("CONFIRMATION_PORT", DEFAULT_PORT))
    cli.add_argument(
        "--port",
        type=int,
        default=env_port,
        help="Bind port (default: 6000)",
    )
    cli.add_argument(
        "--db-path",
        default=None,
        help="SQLite database path (default: ~/.hermes/confirmation_audit.db)",
    )
    env_rate = int(os.getenv("CONFIRMATION_RATE_LIMIT", DEFAULT_RATE_LIMIT))
    cli.add_argument(
        "--rate-limit",
        type=int,
        default=env_rate,
        help="Max confirmations per action per hour (default: 10)",
    )
    opts = cli.parse_args()

    # Checked after argument parsing, before the server is constructed.
    if not AIOHTTP_AVAILABLE:
        print("ERROR: aiohttp is required. Install with: pip install aiohttp")
        raise SystemExit(1)

    daemon = ConfirmationServer(
        host=opts.host,
        port=opts.port,
        db_path=opts.db_path,
        rate_limit=opts.rate_limit,
    )
    print(f"Starting Confirmation Daemon on http://{opts.host}:{opts.port}")
    asyncio.run(daemon.start())


if __name__ == "__main__":
    main()

View File

@@ -107,6 +107,7 @@ class SessionResetPolicy:
mode: str = "both" # "daily", "idle", "both", or "none"
at_hour: int = 4 # Hour for daily reset (0-23, local time)
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
max_messages: int = 200 # Max messages per session before forced checkpoint+restart (0 = unlimited)
notify: bool = True # Send a notification to the user when auto-reset occurs
notify_exclude_platforms: tuple = ("api_server", "webhook") # Platforms that don't get reset notifications
@@ -115,6 +116,7 @@ class SessionResetPolicy:
"mode": self.mode,
"at_hour": self.at_hour,
"idle_minutes": self.idle_minutes,
"max_messages": self.max_messages,
"notify": self.notify,
"notify_exclude_platforms": list(self.notify_exclude_platforms),
}
@@ -125,12 +127,14 @@ class SessionResetPolicy:
mode = data.get("mode")
at_hour = data.get("at_hour")
idle_minutes = data.get("idle_minutes")
max_messages = data.get("max_messages")
notify = data.get("notify")
exclude = data.get("notify_exclude_platforms")
return cls(
mode=mode if mode is not None else "both",
at_hour=at_hour if at_hour is not None else 4,
idle_minutes=idle_minutes if idle_minutes is not None else 1440,
max_messages=max_messages if max_messages is not None else 200,
notify=notify if notify is not None else True,
notify_exclude_platforms=tuple(exclude) if exclude is not None else ("api_server", "webhook"),
)

View File

@@ -2343,6 +2343,12 @@ class GatewayRunner:
reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle'
if reset_reason == "daily":
context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]"
elif reset_reason == "message_limit":
context_note = (
"[System note: The user's previous session reached the message limit "
"and was automatically checkpointed and rotated. This is a fresh session. "
"If the user references something from before, you can search session history.]"
)
else:
context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]"
context_prompt = context_note + "\n\n" + context_prompt
@@ -2368,16 +2374,18 @@ class GatewayRunner:
if adapter:
if reset_reason == "daily":
reason_text = f"daily schedule at {policy.at_hour}:00"
elif reset_reason == "message_limit":
reason_text = f"reached {policy.max_messages} message limit"
else:
hours = policy.idle_minutes // 60
mins = policy.idle_minutes % 60
duration = f"{hours}h" if not mins else f"{hours}h {mins}m" if hours else f"{mins}m"
reason_text = f"inactive for {duration}"
notice = (
f"◐ Session automatically reset ({reason_text}). "
f"Conversation history cleared.\n"
f"◐ Session automatically rotated ({reason_text}). "
f"Conversation was preserved via checkpoint.\n"
f"Use /resume to browse and restore a previous session.\n"
f"Adjust reset timing in config.yaml under session_reset."
f"Adjust limits in config.yaml under session_reset."
)
try:
session_info = self._format_session_info()
@@ -3073,6 +3081,39 @@ class GatewayRunner:
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
)
# Marathon session limit (#326): check if we hit the message cap.
# Auto-checkpoint filesystem and rotate session.
try:
_post_limit = self.session_store.get_message_limit_info(session_key)
if _post_limit["at_limit"] and _post_limit["max_messages"] > 0:
logger.info(
"[Marathon] Session %s hit message limit (%d/%d). Rotating.",
session_key, _post_limit["message_count"], _post_limit["max_messages"],
)
# Attempt filesystem checkpoint before rotation
try:
from tools.checkpoint_manager import CheckpointManager
_cp_cfg_path = _hermes_home / "config.yaml"
if _cp_cfg_path.exists():
import yaml as _cp_yaml
with open(_cp_cfg_path, encoding="utf-8") as _cpf:
_cp_data = _cp_yaml.safe_load(_cpf) or {}
_cp_settings = _cp_data.get("checkpoints", {})
if _cp_settings.get("enabled"):
_cwd = _cp_settings.get("working_dir") or os.getcwd()
mgr = CheckpointManager(max_checkpoints=_cp_settings.get("max_checkpoints", 20))
cp = mgr.create_checkpoint(str(_cwd), label=f"marathon-{session_entry.session_id[:8]}")
if cp:
logger.info("[Marathon] Checkpoint: %s", cp.label)
except Exception as cp_err:
logger.debug("[Marathon] Checkpoint failed (non-fatal): %s", cp_err)
new_entry = self.session_store.reset_session(session_key)
if new_entry:
logger.info("[Marathon] Rotated: %s -> %s", session_entry.session_id, new_entry.session_id)
except Exception as rot_err:
logger.debug("[Marathon] Rotation check failed: %s", rot_err)
# Auto voice reply: send TTS audio before the text response
_already_sent = bool(agent_result.get("already_sent"))
if self._should_send_voice_reply(event, response, agent_messages, already_sent=_already_sent):
@@ -6538,6 +6579,26 @@ class GatewayRunner:
if self._ephemeral_system_prompt:
combined_ephemeral = (combined_ephemeral + "\n\n" + self._ephemeral_system_prompt).strip()
# Marathon session limit warning (#326)
try:
_limit_info = self.session_store.get_message_limit_info(session_key)
if _limit_info["near_limit"] and not _limit_info["at_limit"]:
_remaining = _limit_info["remaining"]
_limit_warn = (
f"[SESSION LIMIT: This session has {_limit_info['message_count']} messages. "
f"Only {_remaining} message(s) remain before automatic session rotation at "
f"{_limit_info['max_messages']} messages. Start wrapping up and save important state.]"
)
combined_ephemeral = (combined_ephemeral + "\n\n" + _limit_warn).strip()
elif _limit_info["at_limit"]:
_limit_warn = (
f"[SESSION LIMIT REACHED: This session has hit the {_limit_info['max_messages']} "
f"message limit. This is your FINAL response. Summarize accomplishments and next steps.]"
)
combined_ephemeral = (combined_ephemeral + "\n\n" + _limit_warn).strip()
except Exception:
pass
# Re-read .env and config for fresh credentials (gateway is long-lived,
# keys may change without restart).
try:

View File

@@ -383,7 +383,11 @@ class SessionEntry:
# survives gateway restarts (the old in-memory _pre_flushed_sessions
# set was lost on restart, causing redundant re-flushes).
memory_flushed: bool = False
# Marathon session limit tracking (#326).
# Counts total messages (user + assistant + tool) in this session.
message_count: int = 0
def to_dict(self) -> Dict[str, Any]:
result = {
"session_key": self.session_key,
@@ -402,6 +406,7 @@ class SessionEntry:
"estimated_cost_usd": self.estimated_cost_usd,
"cost_status": self.cost_status,
"memory_flushed": self.memory_flushed,
"message_count": self.message_count,
}
if self.origin:
result["origin"] = self.origin.to_dict()
@@ -438,6 +443,7 @@ class SessionEntry:
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
cost_status=data.get("cost_status", "unknown"),
memory_flushed=data.get("memory_flushed", False),
message_count=data.get("message_count", 0),
)
@@ -643,6 +649,9 @@ class SessionStore:
)
if policy.mode == "none":
# Even with mode=none, enforce message_limit if set
if policy.max_messages > 0 and entry.message_count >= policy.max_messages:
return "message_limit"
return None
now = _now()
@@ -664,7 +673,11 @@ class SessionStore:
if entry.updated_at < today_reset:
return "daily"
# Marathon session limit (#326): force checkpoint+restart at max_messages
if policy.max_messages > 0 and entry.message_count >= policy.max_messages:
return "message_limit"
return None
def has_any_sessions(self) -> bool:
@@ -822,6 +835,43 @@ class SessionStore:
entry.last_prompt_tokens = last_prompt_tokens
self._save()
def get_message_limit_info(self, session_key: str) -> Dict[str, Any]:
    """Report how close a session is to its message cap (#326).

    Returns a dict with ``message_count``, ``max_messages``, ``remaining``,
    ``near_limit`` (count has reached 85% of the cap), ``at_limit`` and
    ``threshold`` (count / cap). An unknown session yields all-zero/False
    values. When the cap is 0 (unlimited), ``remaining`` is ``inf`` and both
    flags are False.
    """
    with self._lock:
        self._ensure_loaded_locked()
        entry = self._entries.get(session_key)
        if not entry:
            return {"message_count": 0, "max_messages": 0, "remaining": 0,
                    "near_limit": False, "at_limit": False, "threshold": 0.0}
        policy = self.config.get_reset_policy(
            platform=entry.platform,
            session_type=entry.chat_type,
        )
        cap = policy.max_messages
        used = entry.message_count
        capped = cap > 0
        return {
            "message_count": used,
            "max_messages": cap,
            "remaining": max(0, cap - used) if capped else float("inf"),
            "near_limit": capped and used >= int(cap * 0.85),
            "at_limit": capped and used >= cap,
            "threshold": used / cap if capped else 0.0,
        }
def reset_message_count(self, session_key: str) -> None:
    """Zero the message counter of ``session_key`` and persist it (#326).

    Does nothing when the session is unknown.
    """
    with self._lock:
        self._ensure_loaded_locked()
        entry = self._entries.get(session_key)
        if not entry:
            return
        entry.message_count = 0
        self._save()
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
"""Force reset a session, creating a new session ID."""
db_end_session_id = None
@@ -849,6 +899,7 @@ class SessionStore:
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
message_count=0, # Fresh count after rotation (#326)
)
self._entries[session_key] = new_entry
@@ -908,6 +959,7 @@ class SessionStore:
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
message_count=0, # Fresh count after rotation (#326)
)
self._entries[session_key] = new_entry
@@ -966,6 +1018,16 @@ class SessionStore:
transcript_path = self.get_transcript_path(session_id)
with open(transcript_path, "a", encoding="utf-8") as f:
f.write(json.dumps(message, ensure_ascii=False) + "\n")
# Increment message count for marathon session tracking (#326)
# Skip counting session_meta entries (tool defs, metadata)
if message.get("role") != "session_meta":
with self._lock:
for entry in self._entries.values():
if entry.session_id == session_id:
entry.message_count += 1
self._save()
break
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
"""Replace the entire transcript for a session with new messages.

View File

@@ -0,0 +1,184 @@
"""Tests for marathon session limits (#326)."""
import pytest
from datetime import datetime
from pathlib import Path
from tempfile import mkdtemp
from gateway.config import GatewayConfig, Platform, SessionResetPolicy
from gateway.session import SessionEntry, SessionSource, SessionStore
def _source(platform=Platform.LOCAL, chat_id="test"):
    """Build a DM ``SessionSource`` for user ``u1`` on the given platform and chat."""
    return SessionSource(
        chat_type="dm",
        user_id="u1",
        platform=platform,
        chat_id=chat_id,
    )
def _store(max_messages=200, mode="both"):
    """Create a ``SessionStore`` in a fresh temp dir with the given default reset policy."""
    config = GatewayConfig()
    config.default_reset_policy = SessionResetPolicy(
        max_messages=max_messages,
        mode=mode,
    )
    sessions_dir = Path(mkdtemp())
    return SessionStore(sessions_dir, config)
class TestSessionResetPolicyMaxMessages:
    """``max_messages`` on SessionResetPolicy: default, overrides, (de)serialization."""

    def test_default(self):
        policy = SessionResetPolicy()
        assert policy.max_messages == 200

    def test_custom(self):
        policy = SessionResetPolicy(max_messages=500)
        assert policy.max_messages == 500

    def test_unlimited(self):
        # 0 is the "no cap" value.
        policy = SessionResetPolicy(max_messages=0)
        assert policy.max_messages == 0

    def test_to_dict(self):
        serialized = SessionResetPolicy(max_messages=300).to_dict()
        assert serialized["max_messages"] == 300

    def test_from_dict(self):
        policy = SessionResetPolicy.from_dict({"max_messages": 150})
        assert policy.max_messages == 150

    def test_from_dict_default(self):
        policy = SessionResetPolicy.from_dict({})
        assert policy.max_messages == 200
class TestSessionEntryMessageCount:
    """``message_count`` on SessionEntry: default value and dict round-tripping."""

    @staticmethod
    def _entry(**extra):
        """Build a minimal SessionEntry, forwarding any extra field values."""
        return SessionEntry(
            session_key="k",
            session_id="s",
            created_at=datetime.now(),
            updated_at=datetime.now(),
            **extra,
        )

    def test_default(self):
        entry = self._entry()
        assert entry.message_count == 0

    def test_to_dict(self):
        entry = self._entry(message_count=42)
        assert entry.to_dict()["message_count"] == 42

    def test_from_dict(self):
        raw = {
            "session_key": "k",
            "session_id": "s",
            "created_at": "2026-01-01T00:00:00",
            "updated_at": "2026-01-01T00:00:00",
            "message_count": 99,
        }
        entry = SessionEntry.from_dict(raw)
        assert entry.message_count == 99
class TestShouldResetMessageLimit:
    """``_should_reset`` reports "message_limit" once the count reaches the cap."""

    @staticmethod
    def _reset_reason(count, **store_kwargs):
        """Return ``_should_reset`` for a new session whose count is set to ``count``."""
        store = _store(**store_kwargs)
        source = _source()
        entry = store.get_or_create_session(source)
        entry.message_count = count
        return store._should_reset(entry, source)

    def test_at_limit(self):
        assert self._reset_reason(200) == "message_limit"

    def test_over_limit(self):
        assert self._reset_reason(250) == "message_limit"

    def test_below_limit(self):
        assert self._reset_reason(100) is None

    def test_unlimited(self):
        # max_messages=0 disables the cap, whatever the count.
        assert self._reset_reason(9999, max_messages=0, mode="none") is None

    def test_custom_limit(self):
        assert self._reset_reason(50, max_messages=50) == "message_limit"

    def test_just_under(self):
        assert self._reset_reason(49, max_messages=50) is None
class TestAppendIncrementsCount:
    """``append_to_transcript`` bumps ``message_count`` for all but session_meta rows."""

    @staticmethod
    def _count_after(*messages):
        """Append ``messages`` to a new session and return its resulting count."""
        store = _store()
        source = _source()
        session = store.get_or_create_session(source)
        for message in messages:
            store.append_to_transcript(session.session_id, message)
        # Re-fetch so the count is read from the store's current entry.
        return store.get_or_create_session(source).message_count

    def test_user_message(self):
        assert self._count_after({"role": "user", "content": "hi"}) == 1

    def test_assistant_message(self):
        total = self._count_after(
            {"role": "user", "content": "hi"},
            {"role": "assistant", "content": "hello"},
        )
        assert total == 2

    def test_meta_not_counted(self):
        assert self._count_after({"role": "session_meta", "tools": []}) == 0
class TestGetMessageLimitInfo:
    """Checks the dict returned by ``get_message_limit_info``."""

    @staticmethod
    def _info_for(count):
        """Return limit info for a new default-store session forced to *count*."""
        store = _store()
        entry = store.get_or_create_session(_source())
        entry.message_count = count
        return store.get_message_limit_info(entry.session_key)

    def test_at_limit(self):
        """At 200 messages both flags are set and nothing remains."""
        info = self._info_for(200)
        assert info["at_limit"] is True
        assert info["near_limit"] is True
        assert info["remaining"] == 0

    def test_near_limit(self):
        """At 180 messages only ``near_limit`` is set, with 20 remaining."""
        info = self._info_for(180)
        assert info["near_limit"] is True
        assert info["at_limit"] is False
        assert info["remaining"] == 20

    def test_well_below(self):
        """At 50 messages neither flag is set."""
        info = self._info_for(50)
        assert info["near_limit"] is False
        assert info["at_limit"] is False

    def test_unknown(self):
        """An unknown session key is reported as not at the limit."""
        store = _store()
        info = store.get_message_limit_info("nonexistent")
        assert info["at_limit"] is False
class TestResetMessageCount:
    """``reset_message_count`` zeroes the tracked count for a session key."""

    def test_reset(self):
        """After a reset, limit info reports a ``message_count`` of 0."""
        store = _store()
        source = _source()
        entry = store.get_or_create_session(source)
        entry.message_count = 150

        store.reset_message_count(entry.session_key)

        info = store.get_message_limit_info(entry.session_key)
        assert info["message_count"] == 0
class TestSessionRotation:
    """``reset_session`` hands back a new session with a fresh count."""

    def test_fresh_count_after_reset(self):
        """The rotated session exists, starts at zero and has a new id."""
        store = _store()
        source = _source()
        old_entry = store.get_or_create_session(source)
        old_entry.message_count = 200

        rotated = store.reset_session(old_entry.session_key)

        assert rotated is not None
        assert rotated.message_count == 0
        assert rotated.session_id != old_entry.session_id

View File

@@ -1,316 +0,0 @@
"""Tests for the Human Confirmation Daemon."""
import asyncio
import json
import time
from unittest.mock import patch, MagicMock
import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
# We import after the fixtures to avoid aiohttp import issues in test envs
try:
from aiohttp import web
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
if AIOHTTP_AVAILABLE:
from daemon.confirmation_server import (
ConfirmationServer,
ConfirmationRequest,
AuditLog,
RateLimiter,
Whitelist,
HIGH_RISK_ACTIONS,
DEFAULT_RATE_LIMIT,
RATE_LIMIT_WINDOW,
)
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestRateLimiter:
    """Unit tests for the RateLimiter."""

    def test_allows_within_limit(self):
        """Every call up to ``max_per_window`` is accepted."""
        limiter = RateLimiter(max_per_window=3, window=60)
        for _ in range(3):
            assert limiter.check("deploy") is True

    def test_blocks_over_limit(self):
        """The call after the window's quota is used up is rejected."""
        limiter = RateLimiter(max_per_window=2, window=60)
        for _ in range(2):
            assert limiter.check("deploy") is True
        assert limiter.check("deploy") is False

    def test_remaining_count(self):
        """``remaining`` drops by one after a single ``check``."""
        limiter = RateLimiter(max_per_window=5, window=60)
        assert limiter.remaining("deploy") == 5
        limiter.check("deploy")
        assert limiter.remaining("deploy") == 4

    def test_separate_actions_independent(self):
        """Exhausting one action's quota leaves another action untouched."""
        limiter = RateLimiter(max_per_window=2, window=60)
        for _ in range(2):
            assert limiter.check("deploy") is True
        assert limiter.check("deploy") is False
        # A different action name is tracked separately.
        assert limiter.check("shutdown") is True
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestWhitelist:
    """Unit tests for the Whitelist."""

    def test_global_whitelist(self):
        """An action added without a session is whitelisted afterwards."""
        whitelist = Whitelist()
        action = "deploy"
        assert whitelist.is_whitelisted(action) is False
        whitelist.add(action)
        assert whitelist.is_whitelisted(action) is True

    def test_session_scoped_whitelist(self):
        """An action added for one session is not whitelisted for another."""
        whitelist = Whitelist()
        action = "deploy"
        assert whitelist.is_whitelisted(action, "session1") is False
        whitelist.add(action, "session1")
        assert whitelist.is_whitelisted(action, "session1") is True
        assert whitelist.is_whitelisted(action, "session2") is False

    def test_remove(self):
        """Removing an action undoes a previous ``add``."""
        whitelist = Whitelist()
        action = "deploy"
        whitelist.add(action)
        assert whitelist.is_whitelisted(action) is True
        whitelist.remove(action)
        assert whitelist.is_whitelisted(action) is False
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestAuditLog:
    """Unit tests for the AuditLog."""

    def test_log_and_retrieve(self, tmp_path):
        """A logged request comes back from ``recent`` with status ``pending``."""
        db = str(tmp_path / "test_audit.db")
        log = AuditLog(db_path=db)
        # Close the log even when an assertion fails so the handle on the
        # temporary database is not leaked into later tests.
        try:
            req = ConfirmationRequest(
                id="test-123",
                action="deploy_production",
                description="Deploy v2.0 to prod",
                requester="timmy",
            )
            log.log(req)

            entries = log.recent(limit=10)
            assert len(entries) == 1
            assert entries[0]["id"] == "test-123"
            assert entries[0]["action"] == "deploy_production"
            assert entries[0]["status"] == "pending"
        finally:
            log.close()

    def test_update_on_resolve(self, tmp_path):
        """Logging the same request again updates its row instead of adding one."""
        db = str(tmp_path / "test_audit.db")
        log = AuditLog(db_path=db)
        # Close the log even when an assertion fails (see test_log_and_retrieve).
        try:
            req = ConfirmationRequest(
                id="test-456",
                action="delete_data",
                description="Purge old records",
            )
            log.log(req)

            # Resolve
            req.status = "approved"
            req.resolved_by = "alexander"
            req.resolved_at = time.time()
            log.log(req)

            entries = log.recent(limit=10)
            assert len(entries) == 1
            assert entries[0]["status"] == "approved"
            assert entries[0]["resolved_by"] == "alexander"
        finally:
            log.close()
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestConfirmationRequest:
    """Unit tests for the data model."""

    def test_to_dict(self):
        """``to_dict`` exposes id, status, ISO timestamps and the details dict."""
        request = ConfirmationRequest(
            id="abc123",
            action="deploy_production",
            description="Ship it",
            details={"version": "2.0"},
        )
        payload = request.to_dict()

        for key, expected in (("id", "abc123"), ("status", "pending")):
            assert payload[key] == expected
        # A new request has a creation timestamp but no resolution timestamp.
        assert payload["created_at_iso"] is not None
        assert payload["resolved_at_iso"] is None
        assert payload["details"]["version"] == "2.0"
# ---------------------------------------------------------------------------
# Integration tests (HTTP)
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestConfirmationHTTP(AioHTTPTestCase):
    """Full HTTP integration tests for the ConfirmationServer."""

    async def get_application(self):
        """Build the application under test.

        The server is kept on ``self._server`` so tests can reach its
        internals (the whitelist test does).
        """
        # Suppress notification during tests
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            server = ConfirmationServer(
                host="127.0.0.1",
                port=6000,
                db_path=":memory:",
            )
        self._server = server
        return server._build_app()

    async def _submit_confirmation(self, payload):
        """POST *payload* to ``/confirm`` with human notification patched out.

        Returns the response object; callers read status and JSON themselves.
        """
        with patch("daemon.confirmation_server._notify_human", return_value=None):
            return await self.client.request("POST", "/confirm", json=payload)

    @unittest_run_loop
    async def test_health(self):
        """``GET /health`` answers 200 with status ``ok``."""
        resp = await self.client.request("GET", "/health")
        assert resp.status == 200
        data = await resp.json()
        assert data["status"] == "ok"

    @unittest_run_loop
    async def test_submit_confirmation(self):
        """A valid submission is accepted (202) and reported as pending."""
        resp = await self._submit_confirmation({
            "action": "deploy_production",
            "description": "Deploy v2.0 to production",
            "requester": "timmy",
            "session_key": "test-session",
        })
        assert resp.status == 202
        data = await resp.json()
        assert data["status"] == "pending"
        assert data["action"] == "deploy_production"
        assert "id" in data

    @unittest_run_loop
    async def test_submit_missing_action(self):
        """A submission without ``action`` is rejected with 400."""
        resp = await self.client.request("POST", "/confirm", json={
            "description": "Something",
        })
        assert resp.status == 400

    @unittest_run_loop
    async def test_submit_missing_description(self):
        """A submission without ``description`` is rejected with 400."""
        resp = await self.client.request("POST", "/confirm", json={
            "action": "deploy_production",
        })
        assert resp.status == 400

    @unittest_run_loop
    async def test_approve_flow(self):
        """Submit, approve, then read back the approved status."""
        # Submit
        submit_resp = await self._submit_confirmation({
            "action": "deploy_production",
            "description": "Ship it",
        })
        assert submit_resp.status == 202
        submit_data = await submit_resp.json()
        req_id = submit_data["id"]

        # Approve
        approve_resp = await self.client.request(
            "POST", f"/confirm/{req_id}/approve",
            json={"approver": "alexander"},
        )
        assert approve_resp.status == 200
        approve_data = await approve_resp.json()
        assert approve_data["status"] == "approved"
        assert approve_data["resolved_by"] == "alexander"

        # Check status
        status_resp = await self.client.request("GET", f"/confirm/{req_id}")
        assert status_resp.status == 200
        status_data = await status_resp.json()
        assert status_data["status"] == "approved"

    @unittest_run_loop
    async def test_deny_flow(self):
        """Denying a pending request returns the denial and its reason."""
        submit_resp = await self._submit_confirmation({
            "action": "delete_data",
            "description": "Wipe everything",
        })
        req_id = (await submit_resp.json())["id"]

        deny_resp = await self.client.request(
            "POST", f"/confirm/{req_id}/deny",
            json={"denier": "alexander", "reason": "Too risky"},
        )
        assert deny_resp.status == 200
        deny_data = await deny_resp.json()
        assert deny_data["status"] == "denied"
        assert deny_data["reason"] == "Too risky"

    @unittest_run_loop
    async def test_double_approve_returns_409(self):
        """Approving an already-approved request yields 409."""
        submit_resp = await self._submit_confirmation({
            "action": "deploy_production",
            "description": "Ship it",
        })
        req_id = (await submit_resp.json())["id"]

        approve_url = f"/confirm/{req_id}/approve"
        # The first approval resolves the request; only the second is checked.
        await self.client.request("POST", approve_url)
        resp2 = await self.client.request("POST", approve_url)
        assert resp2.status == 409

    @unittest_run_loop
    async def test_not_found(self):
        """Looking up an unknown confirmation id yields 404."""
        resp = await self.client.request("GET", "/confirm/nonexistent")
        assert resp.status == 404

    @unittest_run_loop
    async def test_audit_log(self):
        """After one submission ``GET /audit`` reports at least one entry."""
        await self._submit_confirmation({
            "action": "deploy_production",
            "description": "Ship it",
        })

        resp = await self.client.request("GET", "/audit")
        assert resp.status == 200
        data = await resp.json()
        assert data["count"] >= 1

    @unittest_run_loop
    async def test_rate_limit(self):
        """The request after the per-action quota is answered with 429."""
        # Exhaust rate limit (default is 10)
        # NOTE(review): 10 presumably mirrors DEFAULT_RATE_LIMIT -- confirm
        # before switching this loop to the imported constant.
        for i in range(10):
            await self._submit_confirmation({
                "action": "test_rate_action",
                "description": f"Request {i}",
            })

        # 11th should be rate-limited
        resp = await self._submit_confirmation({
            "action": "test_rate_action",
            "description": "Over the limit",
        })
        assert resp.status == 429

    @unittest_run_loop
    async def test_whitelist_auto_approves(self):
        """A whitelisted action is answered 200 with status ``auto_approved``."""
        self._server._whitelist.add("safe_action")
        resp = await self.client.request("POST", "/confirm", json={
            "action": "safe_action",
            "description": "This is whitelisted",
        })
        assert resp.status == 200
        data = await resp.json()
        assert data["status"] == "auto_approved"