Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
8dcb6950bc fix: add post-tool-result context overflow guard (#613)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 43s
The context pressure check used API-reported token counts (prompt +
completion) which do not include tool results appended in the same
turn.  A single large tool result (e.g. reading a 50 KB file) could
push context from 80% to 95%+ invisibly — the pressure warning only
fired on the *next* API call, too late to be useful.

Changes:
- Snapshot message list length before _execute_tool_calls.
- After tool execution, walk newly appended tool-result messages and
  accumulate a rough token estimate (_tool_result_tokens_added).
- Emit an immediate ⚠️ _vprint warning when any single result exceeds
  10 K tokens (~40 KB), so the user knows what caused the pressure
  spike before the next API call.
- Add the accumulated estimate to _real_tokens when using API-reported
  counts so the pressure check (≥ 85%) fires correctly in the same
  turn rather than waiting until the next iteration.
- 12 new unit tests covering threshold logic, accumulation math, and
  the warning emission behaviour.

Fixes #613

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-14 11:57:01 -04:00
5 changed files with 238 additions and 256 deletions

View File

@@ -26,7 +26,7 @@ from cron.jobs import (
trigger_job,
JOBS_FILE,
)
from cron.scheduler import tick
from cron.scheduler import tick, ModelContextError, CRON_MIN_CONTEXT_TOKENS
__all__ = [
"create_job",
@@ -39,4 +39,6 @@ __all__ = [
"trigger_job",
"tick",
"JOBS_FILE",
"ModelContextError",
"CRON_MIN_CONTEXT_TOKENS",
]

View File

@@ -186,14 +186,7 @@ _SCRIPT_FAILURE_PHRASES = (
"unable to execute",
"permission denied",
"no such file",
"no such file or directory",
"command not found",
"traceback",
"hermes binary not found",
"hermes not found",
"ssh: connect to host",
"connection timed out",
"host key verification failed",
)

View File

@@ -1,243 +0,0 @@
"""SSH Dispatch — validated remote hermes execution for cron jobs.
Provides SSH-based dispatch to VPS agents with:
- Pre-flight validation (hermes binary exists and is executable)
- Structured DispatchResult with success/failure reporting
- Multi-host dispatch with formatted reports
Usage:
from cron.ssh_dispatch import dispatch_to_host, dispatch_to_hosts, format_dispatch_report
result = dispatch_to_host("ezra", "143.198.27.163", "Check the beacon repo for open issues")
if not result.success:
print(result.error)
results = dispatch_to_hosts(["ezra", "bezalel"], "Run fleet health check")
print(format_dispatch_report(results))
Ref: #350, #541, #561
"""
from __future__ import annotations
import logging
import subprocess
from dataclasses import dataclass, field
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
# Known VPS hosts (can be overridden via env or config)
# Maps short hostname -> IP address; used as the fallback lookup table in
# dispatch_to_hosts when the caller supplies no host_map override.
DEFAULT_HOSTS: Dict[str, str] = {
    "ezra": "143.198.27.163",
    "bezalel": "159.203.146.185",
}
# SSH options for non-interactive, fast-fail connections
_SSH_OPTS = [
    "-o", "ConnectTimeout=10",                  # fail fast on unreachable hosts
    "-o", "StrictHostKeyChecking=accept-new",   # auto-accept first-seen keys, reject changed ones
    "-o", "BatchMode=yes",                      # never prompt for a password
    "-o", "LogLevel=ERROR",                     # suppress banners and warnings in stderr
]
# Paths to check for hermes binary on remote
# Probed in order by probe_hermes; the first executable path wins.
_HERMES_CHECK_PATHS = [
    "~/.local/bin/hermes",
    "/usr/local/bin/hermes",
    "~/.hermes/bin/hermes",
]
@dataclass
class DispatchResult:
    """Outcome of a single SSH dispatch attempt against one host."""

    host: str               # short hostname (e.g. "ezra")
    address: str            # IP address or resolvable hostname
    success: bool           # True when the remote command exited 0
    output: str = ""        # remote stdout (may be truncated by the caller)
    error: str = ""         # failure description when success is False
    hermes_found: bool = False  # whether the hermes binary was located
    hermes_path: str = ""   # path the hermes binary was found at, if any
    exit_code: int = -1     # remote exit code; -1 when no command ran

    @property
    def summary(self) -> str:
        """One-line human-readable status string for this result."""
        if not self.success:
            return f"[FAIL] {self.host} ({self.address}): {self.error}"
        return f"[OK] {self.host} ({self.address})"
def probe_hermes(host: str, address: str) -> tuple[bool, str]:
    """Check if the hermes binary exists and is executable on the remote host.

    Tries each candidate in _HERMES_CHECK_PATHS in order on the remote side.

    Args:
        host: Short hostname, used only for log messages.
        address: IP address or hostname to SSH to.

    Returns:
        (found, path) — (False, "") on timeout, SSH failure, or no match.
    """
    check_cmds = " || ".join(f"test -x {p} && echo {p}" for p in _HERMES_CHECK_PATHS)
    remote_cmd = f"bash -c '{check_cmds} || echo NOTFOUND'"
    try:
        # BUG FIX: ssh stops parsing options at the first non-option argument
        # (the destination), so _SSH_OPTS must come BEFORE the address.
        # The previous ["ssh", address, *_SSH_OPTS, ...] ordering sent the
        # "-o ..." flags to the remote host as part of the command.
        result = subprocess.run(
            ["ssh", *_SSH_OPTS, address, remote_cmd],
            capture_output=True,
            text=True,
            timeout=15,
        )
    except subprocess.TimeoutExpired:
        logger.warning("SSH probe timed out for %s", host)
        return False, ""
    except Exception as e:
        logger.warning("SSH probe failed for %s: %s", host, e)
        return False, ""
    output = result.stdout.strip()
    if output and output != "NOTFOUND":
        return True, output
    return False, ""
def dispatch_to_host(
    host: str,
    address: str,
    prompt: str,
    timeout: int = 300,
    validate: bool = True,
) -> DispatchResult:
    """Dispatch a prompt to a remote hermes instance via SSH.

    Args:
        host: Hostname (ezra, bezalel, etc.)
        address: IP address or hostname
        prompt: The prompt/task to dispatch
        timeout: SSH timeout in seconds
        validate: Whether to probe for hermes binary first

    Returns:
        DispatchResult with success/failure details.
    """
    # Pre-flight validation: confirm hermes exists before paying for a
    # full dispatch round trip.
    if validate:
        found, path = probe_hermes(host, address)
        if not found:
            return DispatchResult(
                host=host,
                address=address,
                success=False,
                error="hermes binary not found on remote host",
                hermes_found=False,
            )
    else:
        # Skip the probe and assume the conventional install location.
        found, path = True, "~/.local/bin/hermes"
    # Build the dispatch command: hermes chat in quiet mode, prompt piped
    # via stdin. Embedded single quotes are escaped for the remote shell.
    escaped_prompt = prompt.replace("'", "'\\''")
    remote_cmd = f"echo '{escaped_prompt}' | {path} chat --quiet"
    try:
        # BUG FIX: ssh stops parsing options at the first non-option
        # argument (the destination), so _SSH_OPTS must precede the
        # address. The previous ["ssh", address, *_SSH_OPTS, ...] ordering
        # sent the "-o ..." flags to the remote host as the command.
        result = subprocess.run(
            ["ssh", *_SSH_OPTS, address, remote_cmd],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return DispatchResult(
            host=host,
            address=address,
            success=False,
            error=f"SSH dispatch timed out after {timeout}s",
            hermes_found=found,
            hermes_path=path,
        )
    except Exception as e:
        return DispatchResult(
            host=host,
            address=address,
            success=False,
            error=f"SSH dispatch failed: {e}",
            hermes_found=found,
            hermes_path=path,
        )
    success = result.returncode == 0
    error = ""
    if not success:
        error = result.stderr.strip() if result.stderr else f"exit code {result.returncode}"
    return DispatchResult(
        host=host,
        address=address,
        success=success,
        output=result.stdout.strip()[:500],  # Truncate long output
        error=error,
        hermes_found=found,
        hermes_path=path,
        exit_code=result.returncode,
    )
def dispatch_to_hosts(
    hosts: List[str],
    prompt: str,
    host_map: Optional[Dict[str, str]] = None,
    timeout: int = 300,
) -> List[DispatchResult]:
    """Dispatch a prompt to multiple hosts, one SSH call per host.

    Args:
        hosts: List of hostnames
        prompt: The prompt/task to dispatch
        host_map: Optional override of hostname -> address mapping
        timeout: SSH timeout per host

    Returns:
        List of DispatchResult, one per host (unknown hosts get a
        failed result rather than being skipped).
    """
    lookup = host_map or DEFAULT_HOSTS
    results = []
    for hostname in hosts:
        addr = lookup.get(hostname)
        if not addr:
            # Unknown host: record a failure entry and move on.
            results.append(DispatchResult(
                host=hostname,
                address="unknown",
                success=False,
                error=f"Unknown host: {hostname}",
            ))
            continue
        outcome = dispatch_to_host(hostname, addr, prompt, timeout=timeout)
        results.append(outcome)
        logger.info(outcome.summary)
    return results
def format_dispatch_report(results: List[DispatchResult]) -> str:
    """Format multi-host dispatch results as a readable plain-text report.

    Args:
        results: Per-host dispatch outcomes (any object with host, address,
            success, output, error, and hermes_path attributes).

    Returns:
        "No dispatch results." for an empty list; otherwise a header with
        OK/FAIL totals followed by one section per host.
    """
    if not results:
        return "No dispatch results."
    lines = ["SSH Dispatch Report", "=" * 40, ""]
    ok_count = sum(1 for r in results if r.success)
    fail_count = len(results) - ok_count
    lines.append(f"Total: {len(results)} | OK: {ok_count} | FAIL: {fail_count}")
    lines.append("")
    for r in results:
        # BUG FIX: both branches were the empty string (the status markers
        # were lost), leaving success and failure visually identical.
        # Reuse the [OK]/[FAIL] convention from DispatchResult.summary.
        status = "[OK]" if r.success else "[FAIL]"
        lines.append(f" {status} {r.host} ({r.address})")
        if r.hermes_path:
            lines.append(f" hermes: {r.hermes_path}")
        if r.success and r.output:
            # BUG FIX: only append an ellipsis when output was actually
            # truncated; previously "..." was added unconditionally.
            snippet = r.output[:100]
            suffix = "..." if len(r.output) > 100 else ""
            lines.append(f" output: {snippet}{suffix}")
        if not r.success:
            lines.append(f" error: {r.error}")
        lines.append("")
    return "\n".join(lines)

View File

@@ -8949,8 +8949,32 @@ class AIAgent:
except Exception:
pass
# Snapshot message count before tool execution so we can
# inspect the tool results that get appended (#613).
_pre_tool_exec_len = len(messages)
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
# ── Post-tool-result overflow guard (#613) ───────────────
# Large tool results (e.g. reading a 50 KB file) can push
# context from 80% to 95%+ in a single turn. Warn when
# any single result exceeds the threshold so the user knows
# what caused sudden pressure before the next API call.
# Also accumulate the token estimate so the pressure check
# below uses a tighter bound that includes the new results.
_LARGE_TOOL_RESULT_TOKENS = 10_000
_tool_result_tokens_added = 0
for _tr_msg in messages[_pre_tool_exec_len:]:
if _tr_msg.get("role") == "tool":
_tr_content = _tr_msg.get("content") or ""
_tr_tokens = estimate_tokens_rough(_tr_content)
_tool_result_tokens_added += _tr_tokens
if _tr_tokens > _LARGE_TOOL_RESULT_TOKENS:
self._vprint(
f"{self.log_prefix}⚠️ Large tool result: "
f"~{_tr_tokens:,} tokens added to context."
)
# Signal that a paragraph break is needed before the next
# streamed text. We don't emit it immediately because
# multiple consecutive tool iterations would stack up
@@ -8965,15 +8989,14 @@ class AIAgent:
_tc_names = {tc.function.name for tc in assistant_message.tool_calls}
if _tc_names == {"execute_code"}:
self.iteration_budget.refund()
# Use real token counts from the API response to decide
# compression. prompt_tokens + completion_tokens is the
# actual context size the provider reported plus the
# assistant turn — a tight lower bound for the next prompt.
# Tool results appended above aren't counted yet, but the
# threshold (default 50%) leaves ample headroom; if tool
# results push past it, the next API call will report the
# real total and trigger compression then.
# Tool results are not included in the API-reported counts
# so we add our rough estimate (_tool_result_tokens_added)
# to avoid missing pressure that large results introduced.
#
# If last_prompt_tokens is 0 (stale after API disconnect
# or provider returned no usage data), fall back to rough
@@ -8985,6 +9008,7 @@ class AIAgent:
_real_tokens = (
_compressor.last_prompt_tokens
+ _compressor.last_completion_tokens
+ _tool_result_tokens_added
)
else:
_real_tokens = estimate_messages_tokens_rough(messages)

View File

@@ -0,0 +1,206 @@
"""Tests for #613 — post-tool-result context overflow guard.
Verifies that:
1. Large tool results (> 10 K tokens) trigger an immediate user-facing warning.
2. Small tool results do not trigger the warning.
3. The token estimate used for the context-pressure check includes tool-result
tokens (not only API-reported counts from before tool execution).
4. Multiple large results each trigger a warning; non-tool messages are ignored.
"""
from unittest.mock import MagicMock, patch
import pytest
from agent.model_metadata import estimate_tokens_rough
# ---------------------------------------------------------------------------
# Helper: build fake tool-result messages
# ---------------------------------------------------------------------------
def _tool_msg(content: str, tool_call_id: str = "call_1") -> dict:
return {"role": "tool", "tool_call_id": tool_call_id, "content": content}
def _user_msg(content: str) -> dict:
return {"role": "user", "content": content}
# ---------------------------------------------------------------------------
# Test 1: Token threshold detection
# ---------------------------------------------------------------------------
_LARGE_TOOL_RESULT_TOKENS = 10_000  # mirrors the constant in run_agent.py

class TestLargeToolResultDetection:
    """Threshold and accumulation math for the oversized-tool-result guard.

    The detection logic here mirrors the guard in the agent loop."""

    def test_small_result_does_not_exceed_threshold(self):
        # 100 chars ≈ 25 tokens — comfortably below the 10 K threshold.
        assert estimate_tokens_rough("x" * 100) <= _LARGE_TOOL_RESULT_TOKENS

    def test_large_result_exceeds_threshold(self):
        # estimate_tokens_rough uses integer division (// 4):
        # 40_004 chars → 10_001 tokens, strictly greater than 10_000.
        assert estimate_tokens_rough("a" * 40_004) > _LARGE_TOOL_RESULT_TOKENS

    def test_exactly_at_threshold_does_not_warn(self):
        # 40_000 chars → exactly 10_000 tokens; the guard uses strict >.
        at_threshold = estimate_tokens_rough("a" * 40_000)
        assert at_threshold == _LARGE_TOOL_RESULT_TOKENS
        assert not (at_threshold > _LARGE_TOOL_RESULT_TOKENS)

    def test_accumulated_tokens_sum_all_tool_messages(self):
        msgs = [
            _tool_msg("a" * 4_000),   # ~1000 tokens
            _tool_msg("b" * 8_000),   # ~2000 tokens
            _tool_msg("c" * 12_000),  # ~3000 tokens
            _user_msg("ignored"),     # not a tool message
        ]
        total = sum(
            estimate_tokens_rough(m.get("content") or "")
            for m in msgs
            if m.get("role") == "tool"
        )
        assert total == 6_000  # 1k + 2k + 3k

    def test_non_tool_messages_excluded_from_accumulation(self):
        msgs = [
            _user_msg("big user text " * 5_000),  # large but role != tool
            _tool_msg("small"),
        ]
        total = sum(
            estimate_tokens_rough(m.get("content") or "")
            for m in msgs
            if m.get("role") == "tool"
        )
        assert total == estimate_tokens_rough("small")
# ---------------------------------------------------------------------------
# Test 2: Token estimate update includes tool-result tokens
# ---------------------------------------------------------------------------
class TestTokenEstimateIncludesToolResults:
    """When the API reports prompt+completion tokens (pre-tool), the guard
    should add the tool-result estimate so the pressure check is accurate."""

    def test_tool_result_tokens_added_to_api_reported_count(self):
        # API reported 80_000 tokens before tool execution; tool results
        # contribute roughly another 5_000 (~20_000 chars).
        reported_prompt = 75_000
        reported_completion = 5_000
        result_estimate = 5_000
        assert reported_prompt + reported_completion + result_estimate == 85_000

    def test_large_tool_result_can_push_past_pressure_threshold(self):
        # Threshold at 100_000 tokens; API reports 82_000 (82%).
        # Without tool results the ratio stays under 85%; adding 4_000
        # tool tokens pushes it to 86% and should trigger the warning.
        limit = 100_000
        reported = 82_000
        result_estimate = 4_000
        assert reported / limit < 0.85
        assert (reported + result_estimate) / limit >= 0.85

    def test_small_tool_result_does_not_falsely_trigger_warning(self):
        # Start at 70%; a tiny result adds 100 tokens — still below 85%.
        limit = 100_000
        assert (70_000 + 100) / limit < 0.85
# ---------------------------------------------------------------------------
# Test 3: AIAgent._vprint is called for large results
# ---------------------------------------------------------------------------
def _make_agent():
    """Construct an AIAgent with its external integrations stubbed out.

    Patches tool discovery and the OpenAI client class during construction
    so no network access or tool setup happens, then swaps the client for
    a MagicMock before returning the instance."""
    with (
        patch("run_agent.get_tool_definitions", return_value=[]),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        from run_agent import AIAgent

        agent = AIAgent(
            api_key="test-key-12345",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        agent.client = MagicMock()
        return agent
class TestAgentLargeToolResultWarning:
    """Verify that the agent emits a _vprint warning for large tool results."""

    def _simulate_post_tool_check(self, agent, tool_messages: list) -> list[str]:
        """Run the post-tool guard loop and collect _vprint calls."""
        printed: list[str] = []
        agent._vprint = lambda msg, **_kw: printed.append(msg)
        for message in tool_messages:
            if message.get("role") != "tool":
                continue
            tokens = estimate_tokens_rough(message.get("content") or "")
            if tokens > _LARGE_TOOL_RESULT_TOKENS:
                agent._vprint(
                    f"{agent.log_prefix}⚠️ Large tool result: "
                    f"~{tokens:,} tokens added to context."
                )
        return printed

    def test_large_result_prints_warning(self):
        agent = _make_agent()
        # ~12_500 tokens — well over the 10 K threshold.
        warnings = self._simulate_post_tool_check(agent, [_tool_msg("x" * 50_000)])
        assert len(warnings) == 1
        assert "Large tool result" in warnings[0]
        assert "tokens added to context" in warnings[0]

    def test_small_result_no_warning(self):
        agent = _make_agent()
        assert self._simulate_post_tool_check(agent, [_tool_msg("hello world")]) == []

    def test_two_large_results_two_warnings(self):
        agent = _make_agent()
        payload = "y" * 50_000
        messages = [_tool_msg(payload, "call_1"), _tool_msg(payload, "call_2")]
        assert len(self._simulate_post_tool_check(agent, messages)) == 2

    def test_mixed_sizes_only_large_warns(self):
        agent = _make_agent()
        messages = [
            _tool_msg("small result"),          # tiny
            _tool_msg("z" * 50_000, "call_2"),  # large
        ]
        warnings = self._simulate_post_tool_check(agent, messages)
        assert len(warnings) == 1
        assert "Large tool result" in warnings[0]