Compare commits
1 Commits
fix/500-cl
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8dcb6950bc |
@@ -41,64 +41,6 @@ from agent.model_metadata import is_local_endpoint
|
||||
|
||||
# Module-level logger, named after the importing module.
logger = logging.getLogger(__name__)

# Minimum context tokens required for cron job execution.
CRON_MIN_CONTEXT_TOKENS = 500
|
||||
|
||||
|
||||
class ModelContextError(Exception):
    """Raised when a model does not have enough context tokens for a cron job."""
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Cloud Context Warning — detect local service refs in cloud prompts
|
||||
# =====================================================================
|
||||
|
||||
import re as _re
|
||||
|
||||
# Compiled, case-insensitive patterns that indicate a prompt refers to a
# service only reachable on the local machine.  A prompt is considered to
# reference local services when at least one pattern matches.
_LOCAL_SERVICE_PATTERNS = [
    # Host:port style addresses.
    _re.compile(r'\blocalhost:\d+', _re.IGNORECASE),
    _re.compile(r'\b127\.\d+\.\d+\.\d+:\d+', _re.IGNORECASE),
    _re.compile(r'\b0\.0\.0\.0:\d+', _re.IGNORECASE),
    # Any mention of Ollama.
    _re.compile(r'\bollama\b', _re.IGNORECASE),
    # Shell commands aimed at localhost.
    _re.compile(r'\bcurl\s+localhost\b', _re.IGNORECASE),
    _re.compile(r'\bwget\s+localhost\b', _re.IGNORECASE),
    # URLs; both schemes are accepted for localhost, matching the 127.x rule.
    _re.compile(r'\bhttps?://localhost\b', _re.IGNORECASE),
    _re.compile(r'\bhttps?://127\.\d+\.\d+\.\d+\b', _re.IGNORECASE),
    # Natural-language phrasings.
    _re.compile(r'\bcheck\s+ollama\b', _re.IGNORECASE),
    _re.compile(r'\bconnect\s+local\b', _re.IGNORECASE),
    _re.compile(r'\bhermes\s+gateway\s+local\b', _re.IGNORECASE),
    _re.compile(r'\blocal\s+model\b', _re.IGNORECASE),
]
|
||||
|
||||
# Note placed in front of a cron prompt when the job runs on a non-local
# endpoint but the prompt mentions local services.
_CLOUD_CONTEXT_WARNING = (
    "\n\n[SYSTEM NOTE: This cron job is running on a CLOUD inference endpoint. "
    "Local services (Ollama, localhost, local gateway) are NOT accessible from "
    "this environment. Do not attempt to connect to localhost, run curl/wget "
    "against local ports, or check local model availability. Report the "
    "limitation and focus on tasks achievable remotely.]\n"
)
|
||||
|
||||
|
||||
def _detect_local_service_refs(text: str) -> list[str]:
    """Detect references to local services in prompt text.

    Args:
        text: Prompt text to scan.  Empty or ``None`` input is treated as
            containing no references.

    Returns:
        The source pattern string of every entry in
        ``_LOCAL_SERVICE_PATTERNS`` that matches ``text``, in list order.
        Empty when nothing matches.
    """
    # Guard first: re.Pattern.search raises TypeError on None.
    if not text:
        return []
    return [pat.pattern for pat in _LOCAL_SERVICE_PATTERNS if pat.search(text)]
|
||||
|
||||
|
||||
def _inject_cloud_context(prompt: str, base_url: str) -> str:
    """Prepend the cloud-context warning when a cloud run mentions local services.

    Args:
        prompt: The cron job prompt.
        base_url: Base URL of the inference endpoint the job will use.

    Returns:
        ``prompt`` unchanged when ``is_local_endpoint(base_url)`` is true or
        when no local-service pattern matches; otherwise
        ``_CLOUD_CONTEXT_WARNING`` followed by ``prompt``.
    """
    # Local endpoints can reach local services, so no warning is needed.
    if is_local_endpoint(base_url):
        return prompt

    refs = _detect_local_service_refs(prompt)
    if not refs:
        return prompt

    logger.info(
        "Cloud endpoint + local service refs detected (%d patterns), injecting warning",
        len(refs),
    )
    return _CLOUD_CONTEXT_WARNING + prompt
|
||||
|
||||
|
||||
# =====================================================================
|
||||
# Deploy Sync Guard
|
||||
@@ -875,9 +817,6 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
job_name,
|
||||
)
|
||||
|
||||
# Inject cloud-context warning if prompt references local services (#468)
|
||||
prompt = _inject_cloud_context(prompt, _runtime_base_url)
|
||||
|
||||
_agent_kwargs = _safe_agent_kwargs({
|
||||
"model": turn_route["model"],
|
||||
"api_key": turn_route["runtime"].get("api_key"),
|
||||
|
||||
34
run_agent.py
34
run_agent.py
@@ -8949,8 +8949,32 @@ class AIAgent:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Snapshot message count before tool execution so we can
|
||||
# inspect the tool results that get appended (#613).
|
||||
_pre_tool_exec_len = len(messages)
|
||||
|
||||
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
|
||||
|
||||
# ── Post-tool-result overflow guard (#613) ───────────────
|
||||
# Large tool results (e.g. reading a 50 KB file) can push
|
||||
# context from 80% to 95%+ in a single turn. Warn when
|
||||
# any single result exceeds the threshold so the user knows
|
||||
# what caused sudden pressure before the next API call.
|
||||
# Also accumulate the token estimate so the pressure check
|
||||
# below uses a tighter bound that includes the new results.
|
||||
_LARGE_TOOL_RESULT_TOKENS = 10_000
|
||||
_tool_result_tokens_added = 0
|
||||
for _tr_msg in messages[_pre_tool_exec_len:]:
|
||||
if _tr_msg.get("role") == "tool":
|
||||
_tr_content = _tr_msg.get("content") or ""
|
||||
_tr_tokens = estimate_tokens_rough(_tr_content)
|
||||
_tool_result_tokens_added += _tr_tokens
|
||||
if _tr_tokens > _LARGE_TOOL_RESULT_TOKENS:
|
||||
self._vprint(
|
||||
f"{self.log_prefix}⚠️ Large tool result: "
|
||||
f"~{_tr_tokens:,} tokens added to context."
|
||||
)
|
||||
|
||||
# Signal that a paragraph break is needed before the next
|
||||
# streamed text. We don't emit it immediately because
|
||||
# multiple consecutive tool iterations would stack up
|
||||
@@ -8965,15 +8989,14 @@ class AIAgent:
|
||||
_tc_names = {tc.function.name for tc in assistant_message.tool_calls}
|
||||
if _tc_names == {"execute_code"}:
|
||||
self.iteration_budget.refund()
|
||||
|
||||
|
||||
# Use real token counts from the API response to decide
|
||||
# compression. prompt_tokens + completion_tokens is the
|
||||
# actual context size the provider reported plus the
|
||||
# assistant turn — a tight lower bound for the next prompt.
|
||||
# Tool results appended above aren't counted yet, but the
|
||||
# threshold (default 50%) leaves ample headroom; if tool
|
||||
# results push past it, the next API call will report the
|
||||
# real total and trigger compression then.
|
||||
# Tool results are not included in the API-reported counts
|
||||
# so we add our rough estimate (_tool_result_tokens_added)
|
||||
# to avoid missing pressure that large results introduced.
|
||||
#
|
||||
# If last_prompt_tokens is 0 (stale after API disconnect
|
||||
# or provider returned no usage data), fall back to rough
|
||||
@@ -8985,6 +9008,7 @@ class AIAgent:
|
||||
_real_tokens = (
|
||||
_compressor.last_prompt_tokens
|
||||
+ _compressor.last_completion_tokens
|
||||
+ _tool_result_tokens_added
|
||||
)
|
||||
else:
|
||||
_real_tokens = estimate_messages_tokens_rough(messages)
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
"""Tests for cron cloud-context warning injection (#468)."""
|
||||
|
||||
import pytest
|
||||
|
||||
from cron.scheduler import (
|
||||
_LOCAL_SERVICE_PATTERNS,
|
||||
_detect_local_service_refs,
|
||||
_inject_cloud_context,
|
||||
_CLOUD_CONTEXT_WARNING,
|
||||
)
|
||||
|
||||
|
||||
class TestDetectLocalServiceRefs:
    """Test local service reference detection."""

    def test_detects_localhost_with_port(self):
        refs = _detect_local_service_refs("Connect to localhost:11434")
        assert refs

    def test_detects_127_address(self):
        refs = _detect_local_service_refs("Check http://127.0.0.1:8080/health")
        assert refs

    def test_detects_ollama(self):
        refs = _detect_local_service_refs("Run ollama pull gemma4")
        assert refs

    def test_detects_curl_localhost(self):
        refs = _detect_local_service_refs("curl localhost:11434/api/tags")
        assert refs

    def test_detects_wget_localhost(self):
        refs = _detect_local_service_refs("wget localhost:8080/data")
        assert refs

    def test_detects_http_localhost(self):
        refs = _detect_local_service_refs("http://localhost:3000")
        assert refs

    def test_detects_local_model(self):
        refs = _detect_local_service_refs("Use the local model for inference")
        assert refs

    def test_no_refs_returns_empty(self):
        refs = _detect_local_service_refs("Search the web for Python tutorials")
        assert refs == []

    def test_case_insensitive(self):
        # Mixed-case input must still match.
        refs = _detect_local_service_refs("OLLAMA is running on LocalHost:11434")
        assert refs
|
||||
|
||||
|
||||
class TestInjectCloudContext:
    """Test cloud context warning injection."""

    def test_no_warning_on_local_endpoint(self):
        prompt = "Check ollama on localhost:11434"
        result = _inject_cloud_context(prompt, "http://localhost:11434/v1")
        assert result == prompt  # No injection for local endpoints

    def test_no_warning_when_no_local_refs(self):
        prompt = "Search the web for news"
        result = _inject_cloud_context(prompt, "https://api.openai.com/v1")
        assert result == prompt

    def test_injects_warning_on_cloud_with_local_refs(self):
        prompt = "Check ollama status on localhost:11434"
        result = _inject_cloud_context(prompt, "https://api.openai.com/v1")
        assert _CLOUD_CONTEXT_WARNING in result
        assert prompt in result
        # The warning goes in front of the prompt, not after it.
        assert result.startswith(_CLOUD_CONTEXT_WARNING)

    def test_nous_cloud_injects_warning(self):
        prompt = "curl localhost:11434/api/tags"
        result = _inject_cloud_context(prompt, "https://inference-api.nousresearch.com/v1")
        assert _CLOUD_CONTEXT_WARNING in result

    def test_warning_content(self):
        prompt = "local model check"
        result = _inject_cloud_context(prompt, "https://api.example.com/v1")
        assert "CLOUD" in result
        assert "NOT accessible" in result
        assert "localhost" in result
|
||||
206
tests/test_613_post_tool_overflow_guard.py
Normal file
206
tests/test_613_post_tool_overflow_guard.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Tests for #613 — post-tool-result context overflow guard.
|
||||
|
||||
Verifies that:
|
||||
1. Large tool results (> 10 K tokens) trigger an immediate user-facing warning.
|
||||
2. Small tool results do not trigger the warning.
|
||||
3. The token estimate used for the context-pressure check includes tool-result
|
||||
tokens (not only API-reported counts from before tool execution).
|
||||
4. Multiple large results each trigger a warning; non-tool messages are ignored.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.model_metadata import estimate_tokens_rough
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: build fake tool-result messages
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _tool_msg(content: str, tool_call_id: str = "call_1") -> dict:
    """Build a fake tool-result message dict with role ``"tool"``."""
    return {"role": "tool", "tool_call_id": tool_call_id, "content": content}
|
||||
|
||||
|
||||
def _user_msg(content: str) -> dict:
    """Build a fake user message dict with role ``"user"``."""
    return {"role": "user", "content": content}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 1: Token threshold detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# Per-result token count above which a tool result counts as "large".
# NOTE(review): this is a hand-maintained copy; the agent's own value looks
# like it is assigned inline in the agent loop rather than at module level,
# so it cannot simply be imported — confirm and keep the two in sync.
_LARGE_TOOL_RESULT_TOKENS = 10_000 # mirrors the constant in run_agent.py
|
||||
|
||||
|
||||
class TestLargeToolResultDetection:
    """Logic for detecting oversized tool results mirrors the guard in the
    agent loop.  These tests verify the threshold and accumulation math."""

    def test_small_result_does_not_exceed_threshold(self):
        content = "x" * 100  # ~25 tokens
        tokens = estimate_tokens_rough(content)
        assert tokens <= _LARGE_TOOL_RESULT_TOKENS

    def test_large_result_exceeds_threshold(self):
        # estimate_tokens_rough uses integer division (// 4).
        # 40_004 chars → 10_001 tokens, strictly > 10_000.
        content = "a" * 40_004
        tokens = estimate_tokens_rough(content)
        assert tokens > _LARGE_TOOL_RESULT_TOKENS

    def test_exactly_at_threshold_does_not_warn(self):
        # Exactly 10_000 tokens (40_000 chars) → NOT strictly greater
        content = "a" * 40_000
        tokens = estimate_tokens_rough(content)
        assert tokens == _LARGE_TOOL_RESULT_TOKENS
        assert not (tokens > _LARGE_TOOL_RESULT_TOKENS)

    def test_accumulated_tokens_sum_all_tool_messages(self):
        msgs = [
            _tool_msg("a" * 4_000),   # ~1000 tokens
            _tool_msg("b" * 8_000),   # ~2000 tokens
            _tool_msg("c" * 12_000),  # ~3000 tokens
            _user_msg("ignored"),     # not a tool message
        ]
        total = sum(
            estimate_tokens_rough(m.get("content") or "")
            for m in msgs
            if m.get("role") == "tool"
        )
        assert total == 6_000  # 1k + 2k + 3k

    def test_non_tool_messages_excluded_from_accumulation(self):
        msgs = [
            _user_msg("big user text " * 5_000),  # large but role != tool
            _tool_msg("small"),
        ]
        total = sum(
            estimate_tokens_rough(m.get("content") or "")
            for m in msgs
            if m.get("role") == "tool"
        )
        small_tokens = estimate_tokens_rough("small")
        assert total == small_tokens
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 2: Token estimate update includes tool-result tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTokenEstimateIncludesToolResults:
    """When the API reports prompt+completion tokens (pre-tool), the guard
    should add the tool-result estimate so the pressure check is accurate.

    These tests pin the arithmetic only; they work on local numbers and do
    not call into the agent.
    """

    def test_tool_result_tokens_added_to_api_reported_count(self):
        # Simulate: API reported 80_000 tokens before tool execution.
        # Tool results add ~5_000 tokens.
        api_prompt_tokens = 75_000
        api_completion_tokens = 5_000
        tool_result_tokens_added = 5_000  # rough estimate for 20_000 chars

        real_tokens = api_prompt_tokens + api_completion_tokens + tool_result_tokens_added
        assert real_tokens == 85_000

    def test_large_tool_result_can_push_past_pressure_threshold(self):
        # Threshold at 100_000 tokens; API reports 82_000 (82% of threshold).
        # Without tool results: below 85% → no warning.
        # With 4_000 tool tokens: 86% → warning.
        threshold = 100_000
        api_tokens = 82_000
        tool_tokens = 4_000

        without_tools = api_tokens / threshold
        with_tools = (api_tokens + tool_tokens) / threshold

        assert without_tools < 0.85
        assert with_tools >= 0.85

    def test_small_tool_result_does_not_falsely_trigger_warning(self):
        # Start at 70%; tiny result adds 100 tokens — stays below 85%.
        threshold = 100_000
        api_tokens = 70_000
        tool_tokens = 100

        progress = (api_tokens + tool_tokens) / threshold
        assert progress < 0.85
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test 3: AIAgent._vprint is called for large results
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_agent():
    """Build an ``AIAgent`` for tests with tool discovery and OpenAI patched out.

    ``run_agent.get_tool_definitions``, ``run_agent.check_toolset_requirements``
    and ``run_agent.OpenAI`` are patched while the agent is constructed, and
    the agent's ``client`` attribute is replaced with a ``MagicMock``.
    """
    with (
        patch("run_agent.get_tool_definitions", return_value=[]),
        patch("run_agent.check_toolset_requirements", return_value={}),
        patch("run_agent.OpenAI"),
    ):
        # Imported here so the import happens while the patches are active.
        from run_agent import AIAgent

        a = AIAgent(
            api_key="test-key-12345",
            quiet_mode=True,
            skip_context_files=True,
            skip_memory=True,
        )
        a.client = MagicMock()
        return a
|
||||
|
||||
|
||||
class TestAgentLargeToolResultWarning:
    """Verify that the agent emits a _vprint warning for large tool results."""

    def _simulate_post_tool_check(self, agent, tool_messages: list) -> list[str]:
        """Run the post-tool guard loop and collect _vprint calls.

        The loop below is a copy of the guard logic, not a call into the
        agent; ``agent`` supplies only ``log_prefix`` and the ``_vprint``
        hook, which is replaced with a recorder.
        """
        printed: list[str] = []
        agent._vprint = lambda msg, **_kw: printed.append(msg)

        for _tr_msg in tool_messages:
            if _tr_msg.get("role") == "tool":
                _tr_content = _tr_msg.get("content") or ""
                _tr_tokens = estimate_tokens_rough(_tr_content)
                if _tr_tokens > _LARGE_TOOL_RESULT_TOKENS:
                    agent._vprint(
                        f"{agent.log_prefix}⚠️ Large tool result: "
                        f"~{_tr_tokens:,} tokens added to context."
                    )
        return printed

    def test_large_result_prints_warning(self):
        agent = _make_agent()
        large_content = "x" * 50_000  # ~12_500 tokens
        msgs = [_tool_msg(large_content)]
        warnings = self._simulate_post_tool_check(agent, msgs)
        assert len(warnings) == 1
        assert "Large tool result" in warnings[0]
        assert "tokens added to context" in warnings[0]

    def test_small_result_no_warning(self):
        agent = _make_agent()
        small_content = "hello world"
        msgs = [_tool_msg(small_content)]
        warnings = self._simulate_post_tool_check(agent, msgs)
        assert warnings == []

    def test_two_large_results_two_warnings(self):
        agent = _make_agent()
        large = "y" * 50_000
        msgs = [
            _tool_msg(large, "call_1"),
            _tool_msg(large, "call_2"),
        ]
        warnings = self._simulate_post_tool_check(agent, msgs)
        assert len(warnings) == 2

    def test_mixed_sizes_only_large_warns(self):
        agent = _make_agent()
        msgs = [
            _tool_msg("small result"),          # tiny
            _tool_msg("z" * 50_000, "call_2"),  # large
        ]
        warnings = self._simulate_post_tool_check(agent, msgs)
        assert len(warnings) == 1
        assert "Large tool result" in warnings[0]
|
||||
Reference in New Issue
Block a user