fix(agent): include tool tokens in preflight estimate, guard context probe persistence (#3164)

Two improvements salvaged from PR #2600 (paraddox):

1. Preflight compression now counts tool schema tokens alongside system
   prompt and messages.  With 50+ tools enabled, schemas can add 20-30K
   tokens that were previously invisible to the estimator, delaying
   compression until the API rejected the request.

2. Context probe persistence guard: when the agent steps down context
   tiers after a context-length error, only provider-confirmed numeric
   limits (parsed from the error message) are cached to disk.  Guessed
   fallback tiers from get_next_probe_tier() stay in-memory only,
   preventing wrong values from polluting the persistent cache.

Co-authored-by: paraddox <paraddox@users.noreply.github.com>
This commit is contained in:
Teknium
2026-03-26 02:00:50 -07:00
committed by GitHub
parent 9989e579da
commit 43af094ae3
2 changed files with 52 additions and 10 deletions

View File

@@ -895,3 +895,26 @@ def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
"""Rough token estimate for a message list (pre-flight only)."""
total_chars = sum(len(str(msg)) for msg in messages)
return total_chars // 4
def estimate_request_tokens_rough(
messages: List[Dict[str, Any]],
*,
system_prompt: str = "",
tools: Optional[List[Dict[str, Any]]] = None,
) -> int:
"""Rough token estimate for a full chat-completions request.
Includes the major payload buckets Hermes sends to providers:
system prompt, conversation messages, and tool schemas. With 50+
tools enabled, schemas alone can add 20-30K tokens — a significant
blind spot when only counting messages.
"""
total_chars = 0
if system_prompt:
total_chars += len(system_prompt)
if messages:
total_chars += sum(len(str(msg)) for msg in messages)
if tools:
total_chars += len(str(tools))
return total_chars // 4

View File

@@ -77,7 +77,7 @@ from agent.prompt_builder import (
)
from agent.model_metadata import (
fetch_model_metadata,
estimate_tokens_rough, estimate_messages_tokens_rough,
estimate_tokens_rough, estimate_messages_tokens_rough, estimate_request_tokens_rough,
get_next_probe_tier, parse_context_limit_from_error,
save_context_length,
)
@@ -1133,6 +1133,7 @@ class AIAgent:
self.context_compressor.last_total_tokens = 0
self.context_compressor.compression_count = 0
self.context_compressor._context_probed = False
self.context_compressor._context_probe_persistable = False
# Iterative summary from previous session must not bleed into new one (#2635)
self.context_compressor._previous_summary = None
@@ -5820,9 +5821,13 @@ class AIAgent:
and len(messages) > self.context_compressor.protect_first_n
+ self.context_compressor.protect_last_n + 1
):
_sys_tok_est = estimate_tokens_rough(active_system_prompt or "")
_msg_tok_est = estimate_messages_tokens_rough(messages)
_preflight_tokens = _sys_tok_est + _msg_tok_est
# Include tool schema tokens — with many tools these can add
# 20-30K+ tokens that the old sys+msg estimate missed entirely.
_preflight_tokens = estimate_request_tokens_rough(
messages,
system_prompt=active_system_prompt or "",
tools=self.tools or None,
)
if _preflight_tokens >= self.context_compressor.threshold_tokens:
logger.info(
@@ -5848,9 +5853,11 @@ class AIAgent:
if len(messages) >= _orig_len:
break # Cannot compress further
# Re-estimate after compression
_sys_tok_est = estimate_tokens_rough(active_system_prompt or "")
_msg_tok_est = estimate_messages_tokens_rough(messages)
_preflight_tokens = _sys_tok_est + _msg_tok_est
_preflight_tokens = estimate_request_tokens_rough(
messages,
system_prompt=active_system_prompt or "",
tools=self.tools or None,
)
if _preflight_tokens < self.context_compressor.threshold_tokens:
break # Under threshold
@@ -6313,12 +6320,16 @@ class AIAgent:
}
self.context_compressor.update_from_response(usage_dict)
# Cache discovered context length after successful call
# Cache discovered context length after successful call.
# Only persist limits confirmed by the provider (parsed
# from the error message), not guessed probe tiers.
if self.context_compressor._context_probed:
ctx = self.context_compressor.context_length
save_context_length(self.model, self.base_url, ctx)
self._safe_print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}")
if getattr(self.context_compressor, "_context_probe_persistable", False):
save_context_length(self.model, self.base_url, ctx)
self._safe_print(f"{self.log_prefix}💾 Cached context length: {ctx:,} tokens for {self.model}")
self.context_compressor._context_probed = False
self.context_compressor._context_probe_persistable = False
self.session_prompt_tokens += prompt_tokens
self.session_completion_tokens += completion_tokens
@@ -6619,6 +6630,14 @@ class AIAgent:
compressor.context_length = new_ctx
compressor.threshold_tokens = int(new_ctx * compressor.threshold_percent)
compressor._context_probed = True
# Only persist limits parsed from the provider's
# error message (a real number). Guessed fallback
# tiers from get_next_probe_tier() should stay
# in-memory only — persisting them pollutes the
# cache with wrong values.
compressor._context_probe_persistable = bool(
parsed_limit and parsed_limit == new_ctx
)
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded — stepping down: {old_ctx:,}{new_ctx:,} tokens", force=True)
else:
self._vprint(f"{self.log_prefix}⚠️ Context length exceeded at minimum tier — attempting compression...", force=True)