feat(gateway): cache AIAgent per session for prompt caching

The gateway created a fresh AIAgent per message, rebuilding the system
prompt (including memory, skills, context files) every turn. This broke
prompt prefix caching — providers like Anthropic charge ~10x more for
uncached prefixes.

Now caches AIAgent instances per session_key with a config signature.
The cached agent is reused across messages in the same session,
preserving the frozen system prompt and tool schemas. Cache is
invalidated when:
- Config changes (model, provider, toolsets, reasoning, ephemeral
  prompt) — detected via signature mismatch
- /new, /reset, /clear — explicit session reset
- /model — global model change clears all cached agents
- /reasoning — global reasoning change clears all cached agents

Per-message state (callbacks, stream consumers, progress queues) is
set on the agent instance before each run_conversation() call.

This matches CLI behavior where a single AIAgent lives across all turns
in a session, with _cached_system_prompt built once and reused.
This commit is contained in:
Teknium
2026-03-21 13:07:08 -07:00
parent decc7851f2
commit 342096b4bd
3 changed files with 348 additions and 28 deletions

View File

@@ -344,6 +344,15 @@ class GatewayRunner:
self._running_agents: Dict[str, Any] = {}
self._pending_messages: Dict[str, str] = {} # Queued messages during interrupt
# Cache AIAgent instances per session to preserve prompt caching.
# Without this, a new AIAgent is created per message, rebuilding the
# system prompt (including memory) every turn — breaking prefix cache
# and costing ~10x more on providers with prompt caching (Anthropic).
# Key: session_key, Value: (AIAgent, config_signature_str)
import threading as _threading
self._agent_cache: Dict[str, tuple] = {}
self._agent_cache_lock = _threading.Lock()
# Track active fallback model/provider when primary is rate-limited.
# Set after an agent run where fallback was activated; cleared when
# the primary model succeeds again or the user switches via /model.
@@ -2339,6 +2348,7 @@ class GatewayRunner:
logger.debug("Gateway memory flush on reset failed: %s", e)
self._shutdown_gateway_honcho(session_key)
self._evict_cached_agent(session_key)
# Reset the session
new_entry = self.session_store.reset_session(session_key)
@@ -4364,6 +4374,45 @@ class GatewayRunner:
_MAX_INTERRUPT_DEPTH = 3 # Cap recursive interrupt handling (#816)
@staticmethod
def _agent_config_signature(
model: str,
runtime: dict,
enabled_toolsets: list,
ephemeral_prompt: str,
) -> str:
"""Compute a stable string key from agent config values.
When this signature changes between messages, the cached AIAgent is
discarded and rebuilt. When it stays the same, the cached agent is
reused — preserving the frozen system prompt and tool schemas for
prompt cache hits.
"""
import hashlib, json as _j
blob = _j.dumps(
[
model,
runtime.get("api_key", "")[:8], # first 8 chars only
runtime.get("base_url", ""),
runtime.get("provider", ""),
runtime.get("api_mode", ""),
sorted(enabled_toolsets) if enabled_toolsets else [],
# reasoning_config excluded — it's set per-message on the
# cached agent and doesn't affect system prompt or tools.
ephemeral_prompt or "",
],
sort_keys=True,
default=str,
)
return hashlib.sha256(blob.encode()).hexdigest()[:16]
def _evict_cached_agent(self, session_key: str) -> None:
"""Remove a cached agent for a session (called on /new, /model, etc)."""
_lock = getattr(self, "_agent_cache_lock", None)
if _lock:
with _lock:
self._agent_cache.pop(session_key, None)
async def _run_agent(
self,
message: str,
@@ -4713,34 +4762,64 @@ class GatewayRunner:
logger.debug("Could not set up stream consumer: %s", _sc_err)
turn_route = self._resolve_turn_agent_config(message, model, runtime_kwargs)
agent = AIAgent(
model=turn_route["model"],
**turn_route["runtime"],
max_iterations=max_iterations,
quiet_mode=True,
verbose_logging=False,
enabled_toolsets=enabled_toolsets,
ephemeral_system_prompt=combined_ephemeral or None,
prefill_messages=self._prefill_messages or None,
reasoning_config=reasoning_config,
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
provider_sort=pr.get("sort"),
provider_require_parameters=pr.get("require_parameters", False),
provider_data_collection=pr.get("data_collection"),
session_id=session_id,
tool_progress_callback=progress_callback if tool_progress_enabled else None,
step_callback=_step_callback_sync if _hooks_ref.loaded_hooks else None,
stream_delta_callback=_stream_delta_cb,
status_callback=_status_callback_sync,
platform=platform_key,
honcho_session_key=session_key,
honcho_manager=honcho_manager,
honcho_config=honcho_config,
session_db=self._session_db,
fallback_model=self._fallback_model,
# Check agent cache — reuse the AIAgent from the previous message
# in this session to preserve the frozen system prompt and tool
# schemas for prompt cache hits.
_sig = self._agent_config_signature(
turn_route["model"],
turn_route["runtime"],
enabled_toolsets,
combined_ephemeral,
)
agent = None
_cache_lock = getattr(self, "_agent_cache_lock", None)
_cache = getattr(self, "_agent_cache", None)
if _cache_lock and _cache is not None:
with _cache_lock:
cached = _cache.get(session_key)
if cached and cached[1] == _sig:
agent = cached[0]
logger.debug("Reusing cached agent for session %s", session_key)
if agent is None:
# Config changed or first message — create fresh agent
agent = AIAgent(
model=turn_route["model"],
**turn_route["runtime"],
max_iterations=max_iterations,
quiet_mode=True,
verbose_logging=False,
enabled_toolsets=enabled_toolsets,
ephemeral_system_prompt=combined_ephemeral or None,
prefill_messages=self._prefill_messages or None,
reasoning_config=reasoning_config,
providers_allowed=pr.get("only"),
providers_ignored=pr.get("ignore"),
providers_order=pr.get("order"),
provider_sort=pr.get("sort"),
provider_require_parameters=pr.get("require_parameters", False),
provider_data_collection=pr.get("data_collection"),
session_id=session_id,
platform=platform_key,
honcho_session_key=session_key,
honcho_manager=honcho_manager,
honcho_config=honcho_config,
session_db=self._session_db,
fallback_model=self._fallback_model,
)
if _cache_lock and _cache is not None:
with _cache_lock:
_cache[session_key] = (agent, _sig)
logger.debug("Created new agent for session %s (sig=%s)", session_key, _sig)
# Per-message state — callbacks and reasoning config change every
# turn and must not be baked into the cached agent constructor.
agent.tool_progress_callback = progress_callback if tool_progress_enabled else None
agent.step_callback = _step_callback_sync if _hooks_ref.loaded_hooks else None
agent.stream_delta_callback = _stream_delta_cb
agent.status_callback = _status_callback_sync
agent.reasoning_config = reasoning_config
# Store agent reference for interrupt support
agent_holder[0] = agent
@@ -4985,6 +5064,9 @@ class GatewayRunner:
if _agent.model != _cfg_model:
self._effective_model = _agent.model
self._effective_provider = getattr(_agent, 'provider', None)
# Fallback activated — evict cached agent so the next
# message starts fresh and retries the primary model.
self._evict_cached_agent(session_key)
else:
# Primary model worked — clear any stale fallback state
self._effective_model = None