Compare commits
1 Commits
queue/350-
...
burn/20260
| Author | SHA1 | Date | |
|---|---|---|---|
| 8400381a0d |
@@ -3067,12 +3067,40 @@ class GatewayRunner:
|
||||
|
||||
# Token counts and model are now persisted by the agent directly.
|
||||
# Keep only last_prompt_tokens here for context-window tracking and
|
||||
# compression decisions.
|
||||
# compression decisions. Also persist input/output token totals
|
||||
# so the SessionEntry (sessions.json) and SQLite reflect actual usage.
|
||||
_input_total = agent_result.get("input_tokens", 0) or 0
|
||||
_output_total = agent_result.get("output_tokens", 0) or 0
|
||||
_total_tokens = agent_result.get("total_tokens", 0) or 0
|
||||
_cost_usd = agent_result.get("estimated_cost_usd")
|
||||
self.session_store.update_session(
|
||||
session_entry.session_key,
|
||||
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
|
||||
input_tokens=_input_total,
|
||||
output_tokens=_output_total,
|
||||
total_tokens=_total_tokens,
|
||||
estimated_cost_usd=_cost_usd,
|
||||
)
|
||||
|
||||
# Persist token totals to SQLite so /insights sees real data.
|
||||
# Use absolute=true because the agent's session_*_tokens already
|
||||
# reflect the running total for this conversation turn.
|
||||
if self._session_db:
|
||||
try:
|
||||
_eff_sid = agent_result.get("session_id") or session_entry.session_id
|
||||
self._session_db.set_token_counts(
|
||||
_eff_sid,
|
||||
input_tokens=_input_total,
|
||||
output_tokens=_output_total,
|
||||
cache_read_tokens=agent_result.get("cache_read_tokens", 0) or 0,
|
||||
cache_write_tokens=agent_result.get("cache_write_tokens", 0) or 0,
|
||||
reasoning_tokens=agent_result.get("reasoning_tokens", 0) or 0,
|
||||
estimated_cost_usd=_cost_usd,
|
||||
model=_resolved_model,
|
||||
)
|
||||
except Exception:
|
||||
pass # never block delivery
|
||||
|
||||
# Auto voice reply: send TTS audio before the text response
|
||||
_already_sent = bool(agent_result.get("already_sent"))
|
||||
if self._should_send_voice_reply(event, response, agent_messages, already_sent=_already_sent):
|
||||
|
||||
@@ -810,6 +810,10 @@ class SessionStore:
|
||||
self,
|
||||
session_key: str,
|
||||
last_prompt_tokens: int = None,
|
||||
input_tokens: int = None,
|
||||
output_tokens: int = None,
|
||||
total_tokens: int = None,
|
||||
estimated_cost_usd: float = None,
|
||||
) -> None:
|
||||
"""Update lightweight session metadata after an interaction."""
|
||||
with self._lock:
|
||||
@@ -820,6 +824,14 @@ class SessionStore:
|
||||
entry.updated_at = _now()
|
||||
if last_prompt_tokens is not None:
|
||||
entry.last_prompt_tokens = last_prompt_tokens
|
||||
if input_tokens is not None:
|
||||
entry.input_tokens = input_tokens
|
||||
if output_tokens is not None:
|
||||
entry.output_tokens = output_tokens
|
||||
if total_tokens is not None:
|
||||
entry.total_tokens = total_tokens
|
||||
if estimated_cost_usd is not None:
|
||||
entry.estimated_cost_usd = estimated_cost_usd
|
||||
self._save()
|
||||
|
||||
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
|
||||
|
||||
107
tests/test_token_tracking_persistence.py
Normal file
107
tests/test_token_tracking_persistence.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Tests for gateway token count persistence to SessionEntry and SessionDB.
|
||||
|
||||
Regression test for #316 — token tracking all zeros. The gateway must
|
||||
propagate input_tokens / output_tokens from the agent result to both the
|
||||
SessionEntry (sessions.json) and the SQLite session DB.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.session import SessionEntry
|
||||
|
||||
|
||||
class TestUpdateSessionTokenFields:
|
||||
"""Verify SessionEntry token fields are updated and serialized correctly."""
|
||||
|
||||
def test_session_entry_to_dict_includes_tokens(self):
|
||||
entry = SessionEntry(
|
||||
session_key="tg:123",
|
||||
session_id="sid-1",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
input_tokens=1000,
|
||||
output_tokens=500,
|
||||
total_tokens=1500,
|
||||
estimated_cost_usd=0.05,
|
||||
)
|
||||
d = entry.to_dict()
|
||||
assert d["input_tokens"] == 1000
|
||||
assert d["output_tokens"] == 500
|
||||
assert d["total_tokens"] == 1500
|
||||
assert d["estimated_cost_usd"] == 0.05
|
||||
|
||||
def test_session_entry_from_dict_restores_tokens(self):
|
||||
now = datetime.now().isoformat()
|
||||
data = {
|
||||
"session_key": "tg:123",
|
||||
"session_id": "sid-1",
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"input_tokens": 42,
|
||||
"output_tokens": 21,
|
||||
"total_tokens": 63,
|
||||
"estimated_cost_usd": 0.001,
|
||||
}
|
||||
entry = SessionEntry.from_dict(data)
|
||||
assert entry.input_tokens == 42
|
||||
assert entry.output_tokens == 21
|
||||
assert entry.total_tokens == 63
|
||||
assert entry.estimated_cost_usd == 0.001
|
||||
|
||||
def test_session_entry_roundtrip_preserves_tokens(self):
|
||||
"""to_dict -> from_dict must preserve all token fields."""
|
||||
entry = SessionEntry(
|
||||
session_key="cron:job7",
|
||||
session_id="sid-7",
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
input_tokens=9999,
|
||||
output_tokens=1234,
|
||||
total_tokens=11233,
|
||||
cache_read_tokens=500,
|
||||
cache_write_tokens=100,
|
||||
estimated_cost_usd=0.42,
|
||||
)
|
||||
restored = SessionEntry.from_dict(entry.to_dict())
|
||||
assert restored.input_tokens == 9999
|
||||
assert restored.output_tokens == 1234
|
||||
assert restored.total_tokens == 11233
|
||||
assert restored.cache_read_tokens == 500
|
||||
assert restored.cache_write_tokens == 100
|
||||
assert restored.estimated_cost_usd == 0.42
|
||||
|
||||
|
||||
class TestAgentResultTokenExtraction:
|
||||
"""Verify the gateway extracts token counts from agent_result correctly."""
|
||||
|
||||
def test_agent_result_has_expected_keys(self):
|
||||
"""Simulate what _run_agent returns and verify all token keys exist."""
|
||||
result = {
|
||||
"final_response": "hello",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
"cache_read_tokens": 10,
|
||||
"cache_write_tokens": 5,
|
||||
"reasoning_tokens": 0,
|
||||
"estimated_cost_usd": 0.002,
|
||||
"last_prompt_tokens": 100,
|
||||
"model": "test-model",
|
||||
"session_id": "test-session-123",
|
||||
}
|
||||
# These are the extractions the gateway performs
|
||||
assert result.get("input_tokens", 0) or 0 == 100
|
||||
assert result.get("output_tokens", 0) or 0 == 50
|
||||
assert result.get("total_tokens", 0) or 0 == 150
|
||||
assert result.get("estimated_cost_usd") == 0.002
|
||||
|
||||
def test_agent_result_zero_fallback(self):
|
||||
"""When token keys are missing, defaults to 0."""
|
||||
result = {"final_response": "ok"}
|
||||
assert result.get("input_tokens", 0) or 0 == 0
|
||||
assert result.get("output_tokens", 0) or 0 == 0
|
||||
assert result.get("total_tokens", 0) or 0 == 0
|
||||
Reference in New Issue
Block a user