Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 21s
The gateway's _run_agent returns input_tokens/output_tokens in its result dict, but these were never stored to SessionEntry or the SQLite session DB. Every session showed zero token counts.

Changes:
- gateway/session.py: Extend update_session() to accept and persist input_tokens, output_tokens, total_tokens, estimated_cost_usd
- gateway/run.py: Pass agent result token totals to update_session() and call set_token_counts(absolute=True) on _session_db after every conversation turn
- tests/test_token_tracking_persistence.py: Regression tests for SessionEntry serialization and agent result token extraction

Closes #316
108 lines
3.8 KiB
Python
108 lines
3.8 KiB
Python
"""Tests for gateway token count persistence to SessionEntry and SessionDB.
|
|
|
|
Regression test for #316 — token tracking all zeros. The gateway must
|
|
propagate input_tokens / output_tokens from the agent result to both the
|
|
SessionEntry (sessions.json) and the SQLite session DB.
|
|
"""
|
|
|
|
import json
|
|
from datetime import datetime
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from gateway.session import SessionEntry
|
|
|
|
|
|
class TestUpdateSessionTokenFields:
    """Check that SessionEntry token and cost fields survive (de)serialization."""

    def test_session_entry_to_dict_includes_tokens(self):
        """to_dict() must expose each token/cost value given to the constructor."""
        expected = {
            "input_tokens": 1000,
            "output_tokens": 500,
            "total_tokens": 1500,
            "estimated_cost_usd": 0.05,
        }
        entry = SessionEntry(
            session_key="tg:123",
            session_id="sid-1",
            created_at=datetime.now(),
            updated_at=datetime.now(),
            **expected,
        )

        serialized = entry.to_dict()

        for field, value in expected.items():
            assert serialized[field] == value

    def test_session_entry_from_dict_restores_tokens(self):
        """from_dict() must populate the token/cost attributes from the payload."""
        timestamp = datetime.now().isoformat()
        expected = {
            "input_tokens": 42,
            "output_tokens": 21,
            "total_tokens": 63,
            "estimated_cost_usd": 0.001,
        }
        payload = {
            "session_key": "tg:123",
            "session_id": "sid-1",
            "created_at": timestamp,
            "updated_at": timestamp,
            **expected,
        }

        entry = SessionEntry.from_dict(payload)

        for field, value in expected.items():
            assert getattr(entry, field) == value

    def test_session_entry_roundtrip_preserves_tokens(self):
        """to_dict -> from_dict must preserve all token fields."""
        expected = {
            "input_tokens": 9999,
            "output_tokens": 1234,
            "total_tokens": 11233,
            "cache_read_tokens": 500,
            "cache_write_tokens": 100,
            "estimated_cost_usd": 0.42,
        }
        original = SessionEntry(
            session_key="cron:job7",
            session_id="sid-7",
            created_at=datetime.now(),
            updated_at=datetime.now(),
            **expected,
        )

        restored = SessionEntry.from_dict(original.to_dict())

        for field, value in expected.items():
            assert getattr(restored, field) == value
|
|
|
|
|
|
class TestAgentResultTokenExtraction:
    """Verify the gateway extracts token counts from agent_result correctly."""

    @staticmethod
    def _token_count(result, key):
        """Return ``result[key]`` as a count, mapping a missing or falsy value to 0.

        This is the ``result.get(key, 0) or 0`` extraction exercised by the
        tests below, kept in one place so each assertion compares the
        *extracted* value. Written inline as ``x or 0 == n`` the comparison
        binds tighter than ``or`` and parses as ``x or (0 == n)``, which is
        truthy for any non-zero ``x`` and can never fail.
        """
        return result.get(key, 0) or 0

    def test_agent_result_has_expected_keys(self):
        """Simulate what _run_agent returns and verify all token keys exist."""
        result = {
            "final_response": "hello",
            "input_tokens": 100,
            "output_tokens": 50,
            "total_tokens": 150,
            "cache_read_tokens": 10,
            "cache_write_tokens": 5,
            "reasoning_tokens": 0,
            "estimated_cost_usd": 0.002,
            "last_prompt_tokens": 100,
            "model": "test-model",
            "session_id": "test-session-123",
        }
        # These are the extractions the gateway performs
        assert self._token_count(result, "input_tokens") == 100
        assert self._token_count(result, "output_tokens") == 50
        assert self._token_count(result, "total_tokens") == 150
        assert result.get("estimated_cost_usd") == 0.002

    def test_agent_result_zero_fallback(self):
        """When token keys are missing, defaults to 0."""
        result = {"final_response": "ok"}
        assert self._token_count(result, "input_tokens") == 0
        assert self._token_count(result, "output_tokens") == 0
        assert self._token_count(result, "total_tokens") == 0
|