From a8df7f996404f0786a7e6ef0ee6486cefaad7431 Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Thu, 26 Mar 2026 19:04:53 -0700 Subject: [PATCH] fix: gateway token double-counting with cached agents (#3306) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The cached agent accumulates session_input_tokens across messages, so run_conversation() returns cumulative totals. But update_session() used += (increment), double-counting on every message after the first. - session.py: change in-memory entry updates from += to = (direct assignment for cumulative values) - hermes_state.py: add absolute=True flag to update_token_counts() that uses SET column = ? instead of SET column = column + ? - session.py: pass absolute=True to the DB call CLI path is unchanged — it passes per-API-call deltas directly to update_token_counts() with the default absolute=False (increment). Reported by @zaycruz in #3222. Closes #3222. --- gateway/session.py | 13 ++++++----- hermes_state.py | 41 ++++++++++++++++++++++++++++++----- tests/gateway/test_session.py | 1 + 3 files changed, 45 insertions(+), 10 deletions(-) diff --git a/gateway/session.py b/gateway/session.py index d22c6d043..2d5376b07 100644 --- a/gateway/session.py +++ b/gateway/session.py @@ -762,14 +762,16 @@ class SessionStore: if session_key in self._entries: entry = self._entries[session_key] entry.updated_at = _now() - entry.input_tokens += input_tokens - entry.output_tokens += output_tokens - entry.cache_read_tokens += cache_read_tokens - entry.cache_write_tokens += cache_write_tokens + # Direct assignment — the gateway receives cumulative totals + # from the cached agent, not per-call deltas. + entry.input_tokens = input_tokens + entry.output_tokens = output_tokens + entry.cache_read_tokens = cache_read_tokens + entry.cache_write_tokens = cache_write_tokens if last_prompt_tokens is not None: entry.last_prompt_tokens = last_prompt_tokens if estimated_cost_usd is not None: - entry.estimated_cost_usd += estimated_cost_usd + entry.estimated_cost_usd = estimated_cost_usd if cost_status: entry.cost_status = cost_status entry.total_tokens = ( @@ -795,6 +797,7 @@ class SessionStore: billing_provider=provider, billing_base_url=base_url, model=model, + absolute=True, ) except Exception as e: logger.debug("Session DB operation failed: %s", e) diff --git a/hermes_state.py b/hermes_state.py index 31ed12190..cf03951c7 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -319,11 +319,39 @@ class SessionDB: billing_provider: Optional[str] = None, billing_base_url: Optional[str] = None, billing_mode: Optional[str] = None, + absolute: bool = False, ) -> None: - """Increment token counters and backfill model if not already set.""" - with self._lock: - self._conn.execute( - """UPDATE sessions SET + """Update token counters and backfill model if not already set. + + When *absolute* is False (default), values are **incremented** — use + this for per-API-call deltas (CLI path). + + When *absolute* is True, values are **set directly** — use this when + the caller already holds cumulative totals (gateway path, where the + cached agent accumulates across messages). + """ + if absolute: + sql = """UPDATE sessions SET + input_tokens = ?, + output_tokens = ?, + cache_read_tokens = ?, + cache_write_tokens = ?, + reasoning_tokens = ?, + estimated_cost_usd = COALESCE(?, 0), + actual_cost_usd = CASE + WHEN ? IS NULL THEN actual_cost_usd + ELSE ? + END, + cost_status = COALESCE(?, cost_status), + cost_source = COALESCE(?, cost_source), + pricing_version = COALESCE(?, pricing_version), + billing_provider = COALESCE(billing_provider, ?), + billing_base_url = COALESCE(billing_base_url, ?), + billing_mode = COALESCE(billing_mode, ?), + model = COALESCE(model, ?) + WHERE id = ?""" + else: + sql = """UPDATE sessions SET input_tokens = input_tokens + ?, output_tokens = output_tokens + ?, cache_read_tokens = cache_read_tokens + ?, @@ -341,7 +369,10 @@ class SessionDB: billing_base_url = COALESCE(billing_base_url, ?), billing_mode = COALESCE(billing_mode, ?), model = COALESCE(model, ?) - WHERE id = ?""", + WHERE id = ?""" + with self._lock: + self._conn.execute( + sql, ( input_tokens, output_tokens, diff --git a/tests/gateway/test_session.py b/tests/gateway/test_session.py index 8d4131ab1..226e50593 100644 --- a/tests/gateway/test_session.py +++ b/tests/gateway/test_session.py @@ -858,6 +858,7 @@ class TestLastPromptTokens: billing_provider=None, billing_base_url=None, model="openai/gpt-5.4", + absolute=True, )