Compare commits

..

1 Commits

Author SHA1 Message Date
8400381a0d fix: persist token counts from gateway to SessionEntry and SQLite (#316)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 21s
The gateway's _run_agent returns input_tokens/output_tokens in its
result dict, but these were never stored to SessionEntry or the SQLite
session DB. Every session showed zero token counts.

Changes:
- gateway/session.py: Extend update_session() to accept and persist
  input_tokens, output_tokens, total_tokens, estimated_cost_usd
- gateway/run.py: Pass agent result token totals to update_session()
  and call set_token_counts(absolute=True) on _session_db after
  every conversation turn
- tests/test_token_tracking_persistence.py: Regression tests for
  SessionEntry serialization and agent result token extraction

Closes #316
2026-04-13 17:38:55 -04:00
6 changed files with 148 additions and 986 deletions

View File

@@ -1,5 +0,0 @@
"""Hermes daemon services — long-running background processes."""
from daemon.confirmation_server import ConfirmationServer
__all__ = ["ConfirmationServer"]

View File

@@ -1,664 +0,0 @@
"""Human Confirmation Daemon — route high-risk actions through human review.
HTTP server on port 6000 that intercepts dangerous operations and holds them
until a human approves or denies. Integrates with the existing approval
system (tools/approval.py) and notifies humans via Telegram/Discord/CLI.
Endpoints:
POST /confirm — submit a high-risk action for review
POST /confirm/{id}/approve — approve a pending confirmation
POST /confirm/{id}/deny — deny a pending confirmation
GET /confirm/{id} — check status of a confirmation
GET /audit — recent audit log entries
GET /health — liveness probe
Every decision is logged to SQLite for audit.
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import sqlite3
import threading
import time
import uuid
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
logger = logging.getLogger(__name__)
try:
from aiohttp import web
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
web = None # type: ignore[assignment]
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
DEFAULT_HOST = "127.0.0.1"
DEFAULT_PORT = 6000
# Actions that always require human confirmation (not bypassable)
HIGH_RISK_ACTIONS = {
"deploy_production",
"delete_data",
"transfer_funds",
"modify_permissions",
"shutdown_service",
"wipe_database",
"exec_remote",
"publish_package",
"rotate_keys",
"migrate_database",
}
# Default rate limits: max N confirmations per action type per window
DEFAULT_RATE_LIMIT = 10 # max confirmations per action type
RATE_LIMIT_WINDOW = 3600 # 1 hour in seconds
# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------
@dataclass
class ConfirmationRequest:
"""A single pending or resolved confirmation."""
id: str
action: str
description: str
details: Dict[str, Any] = field(default_factory=dict)
requester: str = "" # agent or user who requested
session_key: str = ""
status: str = "pending" # pending | approved | denied | expired
resolved_by: str = ""
resolved_at: Optional[float] = None
created_at: float = field(default_factory=time.time)
timeout_seconds: int = 300 # 5 min default
def to_dict(self) -> dict:
d = asdict(self)
d["created_at_iso"] = _ts_to_iso(d["created_at"])
d["resolved_at_iso"] = _ts_to_iso(d["resolved_at"]) if d["resolved_at"] else None
return d
def _ts_to_iso(ts: Optional[float]) -> Optional[str]:
if ts is None:
return None
return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat()
# ---------------------------------------------------------------------------
# Audit log (SQLite)
# ---------------------------------------------------------------------------
class AuditLog:
"""SQLite-backed audit log for all confirmation decisions."""
def __init__(self, db_path: Optional[str] = None):
if db_path is None:
home = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
home.mkdir(parents=True, exist_ok=True)
db_path = str(home / "confirmation_audit.db")
self._conn = sqlite3.connect(db_path, check_same_thread=False)
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("""
CREATE TABLE IF NOT EXISTS audit_log (
id TEXT PRIMARY KEY,
action TEXT NOT NULL,
description TEXT NOT NULL,
details TEXT NOT NULL DEFAULT '{}',
requester TEXT NOT NULL DEFAULT '',
session_key TEXT NOT NULL DEFAULT '',
status TEXT NOT NULL,
resolved_by TEXT NOT NULL DEFAULT '',
created_at REAL NOT NULL,
resolved_at REAL,
resolved_at_iso TEXT,
created_at_iso TEXT
)
""")
self._conn.commit()
def log(self, req: ConfirmationRequest) -> None:
d = req.to_dict()
self._conn.execute(
"""INSERT OR REPLACE INTO audit_log
(id, action, description, details, requester, session_key,
status, resolved_by, created_at, resolved_at,
resolved_at_iso, created_at_iso)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)""",
(
d["id"], d["action"], d["description"],
json.dumps(d["details"]), d["requester"], d["session_key"],
d["status"], d["resolved_by"],
d["created_at"], d["resolved_at"],
d["resolved_at_iso"], d["created_at_iso"],
),
)
self._conn.commit()
def recent(self, limit: int = 50) -> List[dict]:
rows = self._conn.execute(
"SELECT * FROM audit_log ORDER BY created_at DESC LIMIT ?", (limit,)
).fetchall()
cols = [d[0] for d in self._conn.description]
return [dict(zip(cols, row)) for row in rows]
def close(self) -> None:
self._conn.close()
# ---------------------------------------------------------------------------
# Rate limiter
# ---------------------------------------------------------------------------
class RateLimiter:
"""Simple sliding-window rate limiter per action type."""
def __init__(self, max_per_window: int = DEFAULT_RATE_LIMIT,
window: int = RATE_LIMIT_WINDOW):
self._max = max_per_window
self._window = window
self._timestamps: Dict[str, List[float]] = {} # action -> [ts, ...]
self._lock = threading.Lock()
def check(self, action: str) -> bool:
"""Return True if the action is within rate limits."""
now = time.time()
with self._lock:
timestamps = self._timestamps.get(action, [])
# Prune expired
timestamps = [t for t in timestamps if now - t < self._window]
self._timestamps[action] = timestamps
if len(timestamps) >= self._max:
return False
timestamps.append(now)
return True
def remaining(self, action: str) -> int:
now = time.time()
with self._lock:
timestamps = self._timestamps.get(action, [])
timestamps = [t for t in timestamps if now - t < self._window]
return max(0, self._max - len(timestamps))
# ---------------------------------------------------------------------------
# Notification dispatcher
# ---------------------------------------------------------------------------
async def _notify_human(request: ConfirmationRequest) -> None:
"""Send a notification about a pending confirmation to humans.
Tries Telegram first, then Discord, then falls back to log warning.
Uses the existing send_message infrastructure.
"""
msg = (
f"\U0001f514 Confirmation Required\n"
f"Action: {request.action}\n"
f"Description: {request.description}\n"
f"Requester: {request.requester}\n"
f"ID: {request.id}\n\n"
f"Approve: POST /confirm/{request.id}/approve\n"
f"Deny: POST /confirm/{request.id}/deny"
)
sent = False
# Try Telegram
try:
from tools.send_message_tool import _handle_send
result = _handle_send({
"target": "telegram",
"message": msg,
})
result_dict = json.loads(result) if isinstance(result, str) else result
if "error" not in result_dict:
sent = True
logger.info("Confirmation %s: notified via Telegram", request.id)
except Exception as e:
logger.debug("Telegram notify failed: %s", e)
# Try Discord if Telegram failed
if not sent:
try:
from tools.send_message_tool import _handle_send
result = _handle_send({
"target": "discord",
"message": msg,
})
result_dict = json.loads(result) if isinstance(result, str) else result
if "error" not in result_dict:
sent = True
logger.info("Confirmation %s: notified via Discord", request.id)
except Exception as e:
logger.debug("Discord notify failed: %s", e)
if not sent:
logger.warning(
"Confirmation %s: no messaging channel available. "
"Action '%s' requires human review -- check /confirm/%s",
request.id, request.action, request.id,
)
# ---------------------------------------------------------------------------
# Whitelist manager
# ---------------------------------------------------------------------------
class Whitelist:
"""Configurable whitelist for actions that skip confirmation."""
def __init__(self):
self._allowed: Dict[str, set] = {} # session_key -> {action, ...}
self._global_allowed: set = set()
self._lock = threading.Lock()
def is_whitelisted(self, action: str, session_key: str = "") -> bool:
with self._lock:
if action in self._global_allowed:
return True
if session_key and action in self._allowed.get(session_key, set()):
return True
return False
def add(self, action: str, session_key: str = "") -> None:
with self._lock:
if session_key:
self._allowed.setdefault(session_key, set()).add(action)
else:
self._global_allowed.add(action)
def remove(self, action: str, session_key: str = "") -> None:
with self._lock:
if session_key:
self._allowed.get(session_key, set()).discard(action)
else:
self._global_allowed.discard(action)
# ---------------------------------------------------------------------------
# Confirmation Server
# ---------------------------------------------------------------------------
class ConfirmationServer:
"""HTTP server for human confirmation of high-risk actions.
Usage:
server = ConfirmationServer()
await server.start() # blocks
# or:
server.start_background() # non-blocking
...
await server.stop()
"""
def __init__(self, host: str = DEFAULT_HOST, port: int = DEFAULT_PORT,
db_path: Optional[str] = None,
rate_limit: int = DEFAULT_RATE_LIMIT):
if not AIOHTTP_AVAILABLE:
raise RuntimeError(
"aiohttp is required for the confirmation daemon. "
"Install with: pip install aiohttp"
)
self._host = host
self._port = port
self._audit = AuditLog(db_path)
self._rate_limiter = RateLimiter(max_per_window=rate_limit)
self._whitelist = Whitelist()
self._pending: Dict[str, ConfirmationRequest] = {}
self._lock = threading.Lock()
self._app: Optional[web.Application] = None
self._runner: Optional[web.AppRunner] = None
self._bg_thread: Optional[threading.Thread] = None
# --- Lifecycle ---
def _build_app(self) -> web.Application:
app = web.Application(client_max_size=1_048_576) # 1 MB
app["server"] = self
app.router.add_post("/confirm", self._handle_submit)
app.router.add_post("/confirm/{req_id}/approve", self._handle_approve)
app.router.add_post("/confirm/{req_id}/deny", self._handle_deny)
app.router.add_get("/confirm/{req_id}", self._handle_status)
app.router.add_get("/audit", self._handle_audit)
app.router.add_get("/health", self._handle_health)
return app
async def start(self) -> None:
"""Start the server and block until stopped."""
self._app = self._build_app()
self._runner = web.AppRunner(self._app)
await self._runner.setup()
site = web.TCPSite(self._runner, self._host, self._port)
await site.start()
logger.info(
"Confirmation daemon listening on http://%s:%s",
self._host, self._port,
)
# Run until cancelled
try:
await asyncio.Event().wait()
except asyncio.CancelledError:
pass
finally:
await self.stop()
def start_background(self) -> None:
"""Start the server in a background thread (non-blocking)."""
def _run():
asyncio.run(self.start())
self._bg_thread = threading.Thread(target=_run, daemon=True)
self._bg_thread.start()
async def stop(self) -> None:
"""Gracefully stop the server."""
if self._runner:
await self._runner.cleanup()
self._runner = None
self._audit.close()
logger.info("Confirmation daemon stopped")
# --- Internal helpers ---
def _get_request(self, req_id: str) -> Optional[ConfirmationRequest]:
with self._lock:
return self._pending.get(req_id)
def _expire_old_requests(self) -> int:
"""Mark expired requests. Returns count expired."""
now = time.time()
expired = 0
with self._lock:
for req in list(self._pending.values()):
if req.status == "pending" and (now - req.created_at) > req.timeout_seconds:
req.status = "expired"
req.resolved_at = now
req.resolved_by = "system:timeout"
self._audit.log(req)
expired += 1
return expired
# --- HTTP handlers ---
async def _handle_submit(self, request: web.Request) -> web.Response:
"""POST /confirm -- submit a new confirmation request."""
try:
body = await request.json()
except Exception:
return web.json_response(
{"error": "Invalid JSON body"}, status=400
)
action = (body.get("action") or "").strip()
description = (body.get("description") or "").strip()
details = body.get("details") or {}
requester = (body.get("requester") or "agent").strip()
session_key = (body.get("session_key") or "").strip()
timeout = body.get("timeout_seconds", 300)
if not action:
return web.json_response(
{"error": "Field 'action' is required"}, status=400
)
if not description:
return web.json_response(
{"error": "Field 'description' is required"}, status=400
)
# Whitelist check
if self._whitelist.is_whitelisted(action, session_key):
auto_id = str(uuid.uuid4())[:8]
return web.json_response({
"id": auto_id,
"action": action,
"status": "auto_approved",
"message": f"Action '{action}' is whitelisted for this session",
})
# Rate limit check
if not self._rate_limiter.check(action):
remaining = self._rate_limiter.remaining(action)
return web.json_response({
"error": f"Rate limit exceeded for action '{action}'",
"remaining": remaining,
"window_seconds": RATE_LIMIT_WINDOW,
}, status=429)
# Enforce timeout bounds
try:
timeout = max(30, min(int(timeout), 3600))
except (ValueError, TypeError):
timeout = 300
# Create request
req = ConfirmationRequest(
id=str(uuid.uuid4())[:12],
action=action,
description=description,
details=details,
requester=requester,
session_key=session_key,
timeout_seconds=timeout,
)
with self._lock:
self._pending[req.id] = req
# Audit log
self._audit.log(req)
# Notify humans (fire-and-forget)
asyncio.create_task(_notify_human(req))
logger.info(
"Confirmation %s submitted: action=%s requester=%s",
req.id, action, requester,
)
return web.json_response({
"id": req.id,
"action": req.action,
"status": req.status,
"timeout_seconds": req.timeout_seconds,
"message": "Confirmation pending. Awaiting human review.",
"approve_url": f"/confirm/{req.id}/approve",
"deny_url": f"/confirm/{req.id}/deny",
}, status=202)
async def _handle_approve(self, request: web.Request) -> web.Response:
"""POST /confirm/{id}/approve -- approve a pending confirmation."""
req_id = request.match_info["req_id"]
req = self._get_request(req_id)
if req is None:
return web.json_response(
{"error": f"Confirmation '{req_id}' not found"}, status=404
)
if req.status != "pending":
return web.json_response({
"error": f"Confirmation '{req_id}' already resolved",
"status": req.status,
}, status=409)
# Parse optional approver identity
try:
body = await request.json()
approver = (body.get("approver") or "api").strip()
except Exception:
approver = "api"
req.status = "approved"
req.resolved_by = approver
req.resolved_at = time.time()
# Audit log
self._audit.log(req)
logger.info(
"Confirmation %s APPROVED by %s (action=%s)",
req_id, approver, req.action,
)
return web.json_response({
"id": req.id,
"action": req.action,
"status": "approved",
"resolved_by": approver,
"resolved_at": req.resolved_at,
})
async def _handle_deny(self, request: web.Request) -> web.Response:
"""POST /confirm/{id}/deny -- deny a pending confirmation."""
req_id = request.match_info["req_id"]
req = self._get_request(req_id)
if req is None:
return web.json_response(
{"error": f"Confirmation '{req_id}' not found"}, status=404
)
if req.status != "pending":
return web.json_response({
"error": f"Confirmation '{req_id}' already resolved",
"status": req.status,
}, status=409)
try:
body = await request.json()
denier = (body.get("denier") or "api").strip()
reason = (body.get("reason") or "").strip()
except Exception:
denier = "api"
reason = ""
req.status = "denied"
req.resolved_by = denier
req.resolved_at = time.time()
# Audit log
self._audit.log(req)
logger.info(
"Confirmation %s DENIED by %s (action=%s, reason=%s)",
req_id, denier, req.action, reason,
)
resp = {
"id": req.id,
"action": req.action,
"status": "denied",
"resolved_by": denier,
"resolved_at": req.resolved_at,
}
if reason:
resp["reason"] = reason
return web.json_response(resp)
async def _handle_status(self, request: web.Request) -> web.Response:
"""GET /confirm/{id} -- check status of a confirmation."""
req_id = request.match_info["req_id"]
req = self._get_request(req_id)
if req is None:
return web.json_response(
{"error": f"Confirmation '{req_id}' not found"}, status=404
)
# Check for expiration
if req.status == "pending":
now = time.time()
if (now - req.created_at) > req.timeout_seconds:
req.status = "expired"
req.resolved_at = now
req.resolved_by = "system:timeout"
self._audit.log(req)
return web.json_response(req.to_dict())
async def _handle_audit(self, request: web.Request) -> web.Response:
"""GET /audit -- recent audit log entries."""
try:
limit = int(request.query.get("limit", "50"))
limit = max(1, min(limit, 500))
except (ValueError, TypeError):
limit = 50
entries = self._audit.recent(limit)
return web.json_response({
"count": len(entries),
"entries": entries,
})
async def _handle_health(self, request: web.Request) -> web.Response:
"""GET /health -- liveness probe."""
return web.json_response({
"status": "ok",
"pending_count": len(self._pending),
"timestamp": time.time(),
})
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def main():
"""Run the confirmation daemon as a standalone process."""
import argparse
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
)
parser = argparse.ArgumentParser(
description="Hermes Human Confirmation Daemon"
)
parser.add_argument(
"--host", default=os.getenv("CONFIRMATION_HOST", DEFAULT_HOST),
help="Bind address (default: 127.0.0.1)",
)
parser.add_argument(
"--port", type=int,
default=int(os.getenv("CONFIRMATION_PORT", DEFAULT_PORT)),
help="Bind port (default: 6000)",
)
parser.add_argument(
"--db-path", default=None,
help="SQLite database path (default: ~/.hermes/confirmation_audit.db)",
)
parser.add_argument(
"--rate-limit", type=int,
default=int(os.getenv("CONFIRMATION_RATE_LIMIT", DEFAULT_RATE_LIMIT)),
help="Max confirmations per action per hour (default: 10)",
)
args = parser.parse_args()
if not AIOHTTP_AVAILABLE:
print("ERROR: aiohttp is required. Install with: pip install aiohttp")
raise SystemExit(1)
server = ConfirmationServer(
host=args.host,
port=args.port,
db_path=args.db_path,
rate_limit=args.rate_limit,
)
print(f"Starting Confirmation Daemon on http://{args.host}:{args.port}")
asyncio.run(server.start())
if __name__ == "__main__":
main()

View File

@@ -3067,12 +3067,40 @@ class GatewayRunner:
# Token counts and model are now persisted by the agent directly.
# Keep only last_prompt_tokens here for context-window tracking and
# compression decisions.
# compression decisions. Also persist input/output token totals
# so the SessionEntry (sessions.json) and SQLite reflect actual usage.
_input_total = agent_result.get("input_tokens", 0) or 0
_output_total = agent_result.get("output_tokens", 0) or 0
_total_tokens = agent_result.get("total_tokens", 0) or 0
_cost_usd = agent_result.get("estimated_cost_usd")
self.session_store.update_session(
session_entry.session_key,
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
input_tokens=_input_total,
output_tokens=_output_total,
total_tokens=_total_tokens,
estimated_cost_usd=_cost_usd,
)
# Persist token totals to SQLite so /insights sees real data.
# Use absolute=true because the agent's session_*_tokens already
# reflect the running total for this conversation turn.
if self._session_db:
try:
_eff_sid = agent_result.get("session_id") or session_entry.session_id
self._session_db.set_token_counts(
_eff_sid,
input_tokens=_input_total,
output_tokens=_output_total,
cache_read_tokens=agent_result.get("cache_read_tokens", 0) or 0,
cache_write_tokens=agent_result.get("cache_write_tokens", 0) or 0,
reasoning_tokens=agent_result.get("reasoning_tokens", 0) or 0,
estimated_cost_usd=_cost_usd,
model=_resolved_model,
)
except Exception:
pass # never block delivery
# Auto voice reply: send TTS audio before the text response
_already_sent = bool(agent_result.get("already_sent"))
if self._should_send_voice_reply(event, response, agent_messages, already_sent=_already_sent):

View File

@@ -810,6 +810,10 @@ class SessionStore:
self,
session_key: str,
last_prompt_tokens: int = None,
input_tokens: int = None,
output_tokens: int = None,
total_tokens: int = None,
estimated_cost_usd: float = None,
) -> None:
"""Update lightweight session metadata after an interaction."""
with self._lock:
@@ -820,6 +824,14 @@ class SessionStore:
entry.updated_at = _now()
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
if input_tokens is not None:
entry.input_tokens = input_tokens
if output_tokens is not None:
entry.output_tokens = output_tokens
if total_tokens is not None:
entry.total_tokens = total_tokens
if estimated_cost_usd is not None:
entry.estimated_cost_usd = estimated_cost_usd
self._save()
def reset_session(self, session_key: str) -> Optional[SessionEntry]:

View File

@@ -1,316 +0,0 @@
"""Tests for the Human Confirmation Daemon."""
import asyncio
import json
import time
from unittest.mock import patch, MagicMock
import pytest
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
# We import after the fixtures to avoid aiohttp import issues in test envs
try:
from aiohttp import web
from aiohttp.test_utils import AioHTTPTestCase, unittest_run_loop
AIOHTTP_AVAILABLE = True
except ImportError:
AIOHTTP_AVAILABLE = False
if AIOHTTP_AVAILABLE:
from daemon.confirmation_server import (
ConfirmationServer,
ConfirmationRequest,
AuditLog,
RateLimiter,
Whitelist,
HIGH_RISK_ACTIONS,
DEFAULT_RATE_LIMIT,
RATE_LIMIT_WINDOW,
)
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestRateLimiter:
"""Unit tests for the RateLimiter."""
def test_allows_within_limit(self):
rl = RateLimiter(max_per_window=3, window=60)
assert rl.check("deploy") is True
assert rl.check("deploy") is True
assert rl.check("deploy") is True
def test_blocks_over_limit(self):
rl = RateLimiter(max_per_window=2, window=60)
assert rl.check("deploy") is True
assert rl.check("deploy") is True
assert rl.check("deploy") is False
def test_remaining_count(self):
rl = RateLimiter(max_per_window=5, window=60)
assert rl.remaining("deploy") == 5
rl.check("deploy")
assert rl.remaining("deploy") == 4
def test_separate_actions_independent(self):
rl = RateLimiter(max_per_window=2, window=60)
assert rl.check("deploy") is True
assert rl.check("deploy") is True
assert rl.check("deploy") is False
assert rl.check("shutdown") is True # different action
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestWhitelist:
"""Unit tests for the Whitelist."""
def test_global_whitelist(self):
wl = Whitelist()
assert wl.is_whitelisted("deploy") is False
wl.add("deploy")
assert wl.is_whitelisted("deploy") is True
def test_session_scoped_whitelist(self):
wl = Whitelist()
assert wl.is_whitelisted("deploy", "session1") is False
wl.add("deploy", "session1")
assert wl.is_whitelisted("deploy", "session1") is True
assert wl.is_whitelisted("deploy", "session2") is False
def test_remove(self):
wl = Whitelist()
wl.add("deploy")
assert wl.is_whitelisted("deploy") is True
wl.remove("deploy")
assert wl.is_whitelisted("deploy") is False
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestAuditLog:
"""Unit tests for the AuditLog."""
def test_log_and_retrieve(self, tmp_path):
db = str(tmp_path / "test_audit.db")
log = AuditLog(db_path=db)
req = ConfirmationRequest(
id="test-123",
action="deploy_production",
description="Deploy v2.0 to prod",
requester="timmy",
)
log.log(req)
entries = log.recent(limit=10)
assert len(entries) == 1
assert entries[0]["id"] == "test-123"
assert entries[0]["action"] == "deploy_production"
assert entries[0]["status"] == "pending"
log.close()
def test_update_on_resolve(self, tmp_path):
db = str(tmp_path / "test_audit.db")
log = AuditLog(db_path=db)
req = ConfirmationRequest(
id="test-456",
action="delete_data",
description="Purge old records",
)
log.log(req)
# Resolve
req.status = "approved"
req.resolved_by = "alexander"
req.resolved_at = time.time()
log.log(req)
entries = log.recent(limit=10)
assert len(entries) == 1
assert entries[0]["status"] == "approved"
assert entries[0]["resolved_by"] == "alexander"
log.close()
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestConfirmationRequest:
"""Unit tests for the data model."""
def test_to_dict(self):
req = ConfirmationRequest(
id="abc123",
action="deploy_production",
description="Ship it",
details={"version": "2.0"},
)
d = req.to_dict()
assert d["id"] == "abc123"
assert d["status"] == "pending"
assert d["created_at_iso"] is not None
assert d["resolved_at_iso"] is None
assert d["details"]["version"] == "2.0"
# ---------------------------------------------------------------------------
# Integration tests (HTTP)
# ---------------------------------------------------------------------------
@pytest.mark.skipif(not AIOHTTP_AVAILABLE, reason="aiohttp not installed")
class TestConfirmationHTTP(AioHTTPTestCase):
"""Full HTTP integration tests for the ConfirmationServer."""
async def get_application(self):
# Suppress notification during tests
with patch("daemon.confirmation_server._notify_human", return_value=None):
server = ConfirmationServer(
host="127.0.0.1",
port=6000,
db_path=":memory:",
)
self._server = server
return server._build_app()
@unittest_run_loop
async def test_health(self):
resp = await self.client.request("GET", "/health")
assert resp.status == 200
data = await resp.json()
assert data["status"] == "ok"
@unittest_run_loop
async def test_submit_confirmation(self):
with patch("daemon.confirmation_server._notify_human", return_value=None):
resp = await self.client.request("POST", "/confirm", json={
"action": "deploy_production",
"description": "Deploy v2.0 to production",
"requester": "timmy",
"session_key": "test-session",
})
assert resp.status == 202
data = await resp.json()
assert data["status"] == "pending"
assert data["action"] == "deploy_production"
assert "id" in data
@unittest_run_loop
async def test_submit_missing_action(self):
resp = await self.client.request("POST", "/confirm", json={
"description": "Something",
})
assert resp.status == 400
@unittest_run_loop
async def test_submit_missing_description(self):
resp = await self.client.request("POST", "/confirm", json={
"action": "deploy_production",
})
assert resp.status == 400
@unittest_run_loop
async def test_approve_flow(self):
# Submit
with patch("daemon.confirmation_server._notify_human", return_value=None):
submit_resp = await self.client.request("POST", "/confirm", json={
"action": "deploy_production",
"description": "Ship it",
})
assert submit_resp.status == 202
submit_data = await submit_resp.json()
req_id = submit_data["id"]
# Approve
approve_resp = await self.client.request(
"POST", f"/confirm/{req_id}/approve",
json={"approver": "alexander"},
)
assert approve_resp.status == 200
approve_data = await approve_resp.json()
assert approve_data["status"] == "approved"
assert approve_data["resolved_by"] == "alexander"
# Check status
status_resp = await self.client.request("GET", f"/confirm/{req_id}")
assert status_resp.status == 200
status_data = await status_resp.json()
assert status_data["status"] == "approved"
@unittest_run_loop
async def test_deny_flow(self):
with patch("daemon.confirmation_server._notify_human", return_value=None):
submit_resp = await self.client.request("POST", "/confirm", json={
"action": "delete_data",
"description": "Wipe everything",
})
req_id = (await submit_resp.json())["id"]
deny_resp = await self.client.request(
"POST", f"/confirm/{req_id}/deny",
json={"denier": "alexander", "reason": "Too risky"},
)
assert deny_resp.status == 200
deny_data = await deny_resp.json()
assert deny_data["status"] == "denied"
assert deny_data["reason"] == "Too risky"
@unittest_run_loop
async def test_double_approve_returns_409(self):
with patch("daemon.confirmation_server._notify_human", return_value=None):
submit_resp = await self.client.request("POST", "/confirm", json={
"action": "deploy_production",
"description": "Ship it",
})
req_id = (await submit_resp.json())["id"]
await self.client.request(f"POST", f"/confirm/{req_id}/approve")
resp2 = await self.client.request(f"POST", f"/confirm/{req_id}/approve")
assert resp2.status == 409
@unittest_run_loop
async def test_not_found(self):
resp = await self.client.request("GET", "/confirm/nonexistent")
assert resp.status == 404
@unittest_run_loop
async def test_audit_log(self):
with patch("daemon.confirmation_server._notify_human", return_value=None):
await self.client.request("POST", "/confirm", json={
"action": "deploy_production",
"description": "Ship it",
})
resp = await self.client.request("GET", "/audit")
assert resp.status == 200
data = await resp.json()
assert data["count"] >= 1
@unittest_run_loop
async def test_rate_limit(self):
# Exhaust rate limit (default is 10)
with patch("daemon.confirmation_server._notify_human", return_value=None):
for i in range(10):
await self.client.request("POST", "/confirm", json={
"action": "test_rate_action",
"description": f"Request {i}",
})
# 11th should be rate-limited
resp = await self.client.request("POST", "/confirm", json={
"action": "test_rate_action",
"description": "Over the limit",
})
assert resp.status == 429
@unittest_run_loop
async def test_whitelist_auto_approves(self):
self._server._whitelist.add("safe_action")
resp = await self.client.request("POST", "/confirm", json={
"action": "safe_action",
"description": "This is whitelisted",
})
assert resp.status == 200
data = await resp.json()
assert data["status"] == "auto_approved"

View File

@@ -0,0 +1,107 @@
"""Tests for gateway token count persistence to SessionEntry and SessionDB.
Regression test for #316 — token tracking all zeros. The gateway must
propagate input_tokens / output_tokens from the agent result to both the
SessionEntry (sessions.json) and the SQLite session DB.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock
import pytest
from gateway.session import SessionEntry
class TestUpdateSessionTokenFields:
"""Verify SessionEntry token fields are updated and serialized correctly."""
def test_session_entry_to_dict_includes_tokens(self):
entry = SessionEntry(
session_key="tg:123",
session_id="sid-1",
created_at=datetime.now(),
updated_at=datetime.now(),
input_tokens=1000,
output_tokens=500,
total_tokens=1500,
estimated_cost_usd=0.05,
)
d = entry.to_dict()
assert d["input_tokens"] == 1000
assert d["output_tokens"] == 500
assert d["total_tokens"] == 1500
assert d["estimated_cost_usd"] == 0.05
def test_session_entry_from_dict_restores_tokens(self):
now = datetime.now().isoformat()
data = {
"session_key": "tg:123",
"session_id": "sid-1",
"created_at": now,
"updated_at": now,
"input_tokens": 42,
"output_tokens": 21,
"total_tokens": 63,
"estimated_cost_usd": 0.001,
}
entry = SessionEntry.from_dict(data)
assert entry.input_tokens == 42
assert entry.output_tokens == 21
assert entry.total_tokens == 63
assert entry.estimated_cost_usd == 0.001
def test_session_entry_roundtrip_preserves_tokens(self):
"""to_dict -> from_dict must preserve all token fields."""
entry = SessionEntry(
session_key="cron:job7",
session_id="sid-7",
created_at=datetime.now(),
updated_at=datetime.now(),
input_tokens=9999,
output_tokens=1234,
total_tokens=11233,
cache_read_tokens=500,
cache_write_tokens=100,
estimated_cost_usd=0.42,
)
restored = SessionEntry.from_dict(entry.to_dict())
assert restored.input_tokens == 9999
assert restored.output_tokens == 1234
assert restored.total_tokens == 11233
assert restored.cache_read_tokens == 500
assert restored.cache_write_tokens == 100
assert restored.estimated_cost_usd == 0.42
class TestAgentResultTokenExtraction:
"""Verify the gateway extracts token counts from agent_result correctly."""
def test_agent_result_has_expected_keys(self):
"""Simulate what _run_agent returns and verify all token keys exist."""
result = {
"final_response": "hello",
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
"cache_read_tokens": 10,
"cache_write_tokens": 5,
"reasoning_tokens": 0,
"estimated_cost_usd": 0.002,
"last_prompt_tokens": 100,
"model": "test-model",
"session_id": "test-session-123",
}
# These are the extractions the gateway performs
assert result.get("input_tokens", 0) or 0 == 100
assert result.get("output_tokens", 0) or 0 == 50
assert result.get("total_tokens", 0) or 0 == 150
assert result.get("estimated_cost_usd") == 0.002
def test_agent_result_zero_fallback(self):
"""When token keys are missing, defaults to 0."""
result = {"final_response": "ok"}
assert result.get("input_tokens", 0) or 0 == 0
assert result.get("output_tokens", 0) or 0 == 0
assert result.get("total_tokens", 0) or 0 == 0