Compare commits

..

1 Commits

Author SHA1 Message Date
8400381a0d fix: persist token counts from gateway to SessionEntry and SQLite (#316)
Some checks failed
Forge CI / smoke-and-build (pull_request) Failing after 21s
The gateway's _run_agent returns input_tokens/output_tokens in its
result dict, but these were never stored to SessionEntry or the SQLite
session DB. Every session showed zero token counts.

Changes:
- gateway/session.py: Extend update_session() to accept and persist
  input_tokens, output_tokens, total_tokens, estimated_cost_usd
- gateway/run.py: Pass agent result token totals to update_session()
  and call set_token_counts(absolute=True) on _session_db after
  every conversation turn
- tests/test_token_tracking_persistence.py: Regression tests for
  SessionEntry serialization and agent result token extraction

Closes #316
2026-04-13 17:38:55 -04:00
5 changed files with 151 additions and 315 deletions

View File

@@ -107,7 +107,6 @@ class SessionResetPolicy:
mode: str = "both" # "daily", "idle", "both", or "none"
at_hour: int = 4 # Hour for daily reset (0-23, local time)
idle_minutes: int = 1440 # Minutes of inactivity before reset (24 hours)
max_messages: int = 200 # Max messages per session before forced checkpoint+restart (0 = unlimited)
notify: bool = True # Send a notification to the user when auto-reset occurs
notify_exclude_platforms: tuple = ("api_server", "webhook") # Platforms that don't get reset notifications
@@ -116,7 +115,6 @@ class SessionResetPolicy:
"mode": self.mode,
"at_hour": self.at_hour,
"idle_minutes": self.idle_minutes,
"max_messages": self.max_messages,
"notify": self.notify,
"notify_exclude_platforms": list(self.notify_exclude_platforms),
}
@@ -127,14 +125,12 @@ class SessionResetPolicy:
mode = data.get("mode")
at_hour = data.get("at_hour")
idle_minutes = data.get("idle_minutes")
max_messages = data.get("max_messages")
notify = data.get("notify")
exclude = data.get("notify_exclude_platforms")
return cls(
mode=mode if mode is not None else "both",
at_hour=at_hour if at_hour is not None else 4,
idle_minutes=idle_minutes if idle_minutes is not None else 1440,
max_messages=max_messages if max_messages is not None else 200,
notify=notify if notify is not None else True,
notify_exclude_platforms=tuple(exclude) if exclude is not None else ("api_server", "webhook"),
)

View File

@@ -2343,12 +2343,6 @@ class GatewayRunner:
reset_reason = getattr(session_entry, 'auto_reset_reason', None) or 'idle'
if reset_reason == "daily":
context_note = "[System note: The user's session was automatically reset by the daily schedule. This is a fresh conversation with no prior context.]"
elif reset_reason == "message_limit":
context_note = (
"[System note: The user's previous session reached the message limit "
"and was automatically checkpointed and rotated. This is a fresh session. "
"If the user references something from before, you can search session history.]"
)
else:
context_note = "[System note: The user's previous session expired due to inactivity. This is a fresh conversation with no prior context.]"
context_prompt = context_note + "\n\n" + context_prompt
@@ -2374,18 +2368,16 @@ class GatewayRunner:
if adapter:
if reset_reason == "daily":
reason_text = f"daily schedule at {policy.at_hour}:00"
elif reset_reason == "message_limit":
reason_text = f"reached {policy.max_messages} message limit"
else:
hours = policy.idle_minutes // 60
mins = policy.idle_minutes % 60
duration = f"{hours}h" if not mins else f"{hours}h {mins}m" if hours else f"{mins}m"
reason_text = f"inactive for {duration}"
notice = (
f"◐ Session automatically rotated ({reason_text}). "
f"Conversation was preserved via checkpoint.\n"
f"◐ Session automatically reset ({reason_text}). "
f"Conversation history cleared.\n"
f"Use /resume to browse and restore a previous session.\n"
f"Adjust limits in config.yaml under session_reset."
f"Adjust reset timing in config.yaml under session_reset."
)
try:
session_info = self._format_session_info()
@@ -3075,44 +3067,39 @@ class GatewayRunner:
# Token counts and model are now persisted by the agent directly.
# Keep only last_prompt_tokens here for context-window tracking and
# compression decisions.
# compression decisions. Also persist input/output token totals
# so the SessionEntry (sessions.json) and SQLite reflect actual usage.
_input_total = agent_result.get("input_tokens", 0) or 0
_output_total = agent_result.get("output_tokens", 0) or 0
_total_tokens = agent_result.get("total_tokens", 0) or 0
_cost_usd = agent_result.get("estimated_cost_usd")
self.session_store.update_session(
session_entry.session_key,
last_prompt_tokens=agent_result.get("last_prompt_tokens", 0),
input_tokens=_input_total,
output_tokens=_output_total,
total_tokens=_total_tokens,
estimated_cost_usd=_cost_usd,
)
# Marathon session limit (#326): check if we hit the message cap.
# Auto-checkpoint filesystem and rotate session.
try:
_post_limit = self.session_store.get_message_limit_info(session_key)
if _post_limit["at_limit"] and _post_limit["max_messages"] > 0:
logger.info(
"[Marathon] Session %s hit message limit (%d/%d). Rotating.",
session_key, _post_limit["message_count"], _post_limit["max_messages"],
# Persist token totals to SQLite so /insights sees real data.
# Use absolute=true because the agent's session_*_tokens already
# reflect the running total for this conversation turn.
if self._session_db:
try:
_eff_sid = agent_result.get("session_id") or session_entry.session_id
self._session_db.set_token_counts(
_eff_sid,
input_tokens=_input_total,
output_tokens=_output_total,
cache_read_tokens=agent_result.get("cache_read_tokens", 0) or 0,
cache_write_tokens=agent_result.get("cache_write_tokens", 0) or 0,
reasoning_tokens=agent_result.get("reasoning_tokens", 0) or 0,
estimated_cost_usd=_cost_usd,
model=_resolved_model,
)
# Attempt filesystem checkpoint before rotation
try:
from tools.checkpoint_manager import CheckpointManager
_cp_cfg_path = _hermes_home / "config.yaml"
if _cp_cfg_path.exists():
import yaml as _cp_yaml
with open(_cp_cfg_path, encoding="utf-8") as _cpf:
_cp_data = _cp_yaml.safe_load(_cpf) or {}
_cp_settings = _cp_data.get("checkpoints", {})
if _cp_settings.get("enabled"):
_cwd = _cp_settings.get("working_dir") or os.getcwd()
mgr = CheckpointManager(max_checkpoints=_cp_settings.get("max_checkpoints", 20))
cp = mgr.create_checkpoint(str(_cwd), label=f"marathon-{session_entry.session_id[:8]}")
if cp:
logger.info("[Marathon] Checkpoint: %s", cp.label)
except Exception as cp_err:
logger.debug("[Marathon] Checkpoint failed (non-fatal): %s", cp_err)
new_entry = self.session_store.reset_session(session_key)
if new_entry:
logger.info("[Marathon] Rotated: %s -> %s", session_entry.session_id, new_entry.session_id)
except Exception as rot_err:
logger.debug("[Marathon] Rotation check failed: %s", rot_err)
except Exception:
pass # never block delivery
# Auto voice reply: send TTS audio before the text response
_already_sent = bool(agent_result.get("already_sent"))
@@ -6579,26 +6566,6 @@ class GatewayRunner:
if self._ephemeral_system_prompt:
combined_ephemeral = (combined_ephemeral + "\n\n" + self._ephemeral_system_prompt).strip()
# Marathon session limit warning (#326)
try:
_limit_info = self.session_store.get_message_limit_info(session_key)
if _limit_info["near_limit"] and not _limit_info["at_limit"]:
_remaining = _limit_info["remaining"]
_limit_warn = (
f"[SESSION LIMIT: This session has {_limit_info['message_count']} messages. "
f"Only {_remaining} message(s) remain before automatic session rotation at "
f"{_limit_info['max_messages']} messages. Start wrapping up and save important state.]"
)
combined_ephemeral = (combined_ephemeral + "\n\n" + _limit_warn).strip()
elif _limit_info["at_limit"]:
_limit_warn = (
f"[SESSION LIMIT REACHED: This session has hit the {_limit_info['max_messages']} "
f"message limit. This is your FINAL response. Summarize accomplishments and next steps.]"
)
combined_ephemeral = (combined_ephemeral + "\n\n" + _limit_warn).strip()
except Exception:
pass
# Re-read .env and config for fresh credentials (gateway is long-lived,
# keys may change without restart).
try:

View File

@@ -383,11 +383,7 @@ class SessionEntry:
# survives gateway restarts (the old in-memory _pre_flushed_sessions
# set was lost on restart, causing redundant re-flushes).
memory_flushed: bool = False
# Marathon session limit tracking (#326).
# Counts total messages (user + assistant + tool) in this session.
message_count: int = 0
def to_dict(self) -> Dict[str, Any]:
result = {
"session_key": self.session_key,
@@ -406,7 +402,6 @@ class SessionEntry:
"estimated_cost_usd": self.estimated_cost_usd,
"cost_status": self.cost_status,
"memory_flushed": self.memory_flushed,
"message_count": self.message_count,
}
if self.origin:
result["origin"] = self.origin.to_dict()
@@ -443,7 +438,6 @@ class SessionEntry:
estimated_cost_usd=data.get("estimated_cost_usd", 0.0),
cost_status=data.get("cost_status", "unknown"),
memory_flushed=data.get("memory_flushed", False),
message_count=data.get("message_count", 0),
)
@@ -649,9 +643,6 @@ class SessionStore:
)
if policy.mode == "none":
# Even with mode=none, enforce message_limit if set
if policy.max_messages > 0 and entry.message_count >= policy.max_messages:
return "message_limit"
return None
now = _now()
@@ -673,11 +664,7 @@ class SessionStore:
if entry.updated_at < today_reset:
return "daily"
# Marathon session limit (#326): force checkpoint+restart at max_messages
if policy.max_messages > 0 and entry.message_count >= policy.max_messages:
return "message_limit"
return None
def has_any_sessions(self) -> bool:
@@ -823,6 +810,10 @@ class SessionStore:
self,
session_key: str,
last_prompt_tokens: int = None,
input_tokens: int = None,
output_tokens: int = None,
total_tokens: int = None,
estimated_cost_usd: float = None,
) -> None:
"""Update lightweight session metadata after an interaction."""
with self._lock:
@@ -833,43 +824,14 @@ class SessionStore:
entry.updated_at = _now()
if last_prompt_tokens is not None:
entry.last_prompt_tokens = last_prompt_tokens
self._save()
def get_message_limit_info(self, session_key: str) -> Dict[str, Any]:
"""Get message count and limit info for a session (#326)."""
with self._lock:
self._ensure_loaded_locked()
entry = self._entries.get(session_key)
if not entry:
return {"message_count": 0, "max_messages": 0, "remaining": 0,
"near_limit": False, "at_limit": False, "threshold": 0.0}
policy = self.config.get_reset_policy(
platform=entry.platform,
session_type=entry.chat_type,
)
max_msgs = policy.max_messages
count = entry.message_count
remaining = max(0, max_msgs - count) if max_msgs > 0 else float("inf")
threshold = count / max_msgs if max_msgs > 0 else 0.0
return {
"message_count": count,
"max_messages": max_msgs,
"remaining": remaining,
"near_limit": max_msgs > 0 and count >= int(max_msgs * 0.85),
"at_limit": max_msgs > 0 and count >= max_msgs,
"threshold": threshold,
}
def reset_message_count(self, session_key: str) -> None:
"""Reset the message count to zero for a session (#326)."""
with self._lock:
self._ensure_loaded_locked()
entry = self._entries.get(session_key)
if entry:
entry.message_count = 0
if input_tokens is not None:
entry.input_tokens = input_tokens
if output_tokens is not None:
entry.output_tokens = output_tokens
if total_tokens is not None:
entry.total_tokens = total_tokens
if estimated_cost_usd is not None:
entry.estimated_cost_usd = estimated_cost_usd
self._save()
def reset_session(self, session_key: str) -> Optional[SessionEntry]:
@@ -899,7 +861,6 @@ class SessionStore:
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
message_count=0, # Fresh count after rotation (#326)
)
self._entries[session_key] = new_entry
@@ -959,7 +920,6 @@ class SessionStore:
display_name=old_entry.display_name,
platform=old_entry.platform,
chat_type=old_entry.chat_type,
message_count=0, # Fresh count after rotation (#326)
)
self._entries[session_key] = new_entry
@@ -1018,16 +978,6 @@ class SessionStore:
transcript_path = self.get_transcript_path(session_id)
with open(transcript_path, "a", encoding="utf-8") as f:
f.write(json.dumps(message, ensure_ascii=False) + "\n")
# Increment message count for marathon session tracking (#326)
# Skip counting session_meta entries (tool defs, metadata)
if message.get("role") != "session_meta":
with self._lock:
for entry in self._entries.values():
if entry.session_id == session_id:
entry.message_count += 1
self._save()
break
def rewrite_transcript(self, session_id: str, messages: List[Dict[str, Any]]) -> None:
"""Replace the entire transcript for a session with new messages.

View File

@@ -1,184 +0,0 @@
"""Tests for marathon session limits (#326)."""
import pytest
from datetime import datetime
from pathlib import Path
from tempfile import mkdtemp
from gateway.config import GatewayConfig, Platform, SessionResetPolicy
from gateway.session import SessionEntry, SessionSource, SessionStore
def _source(platform=Platform.LOCAL, chat_id="test"):
return SessionSource(platform=platform, chat_id=chat_id, chat_type="dm", user_id="u1")
def _store(max_messages=200, mode="both"):
cfg = GatewayConfig()
cfg.default_reset_policy = SessionResetPolicy(mode=mode, max_messages=max_messages)
return SessionStore(Path(mkdtemp()), cfg)
class TestSessionResetPolicyMaxMessages:
def test_default(self):
assert SessionResetPolicy().max_messages == 200
def test_custom(self):
assert SessionResetPolicy(max_messages=500).max_messages == 500
def test_unlimited(self):
assert SessionResetPolicy(max_messages=0).max_messages == 0
def test_to_dict(self):
d = SessionResetPolicy(max_messages=300).to_dict()
assert d["max_messages"] == 300
def test_from_dict(self):
p = SessionResetPolicy.from_dict({"max_messages": 150})
assert p.max_messages == 150
def test_from_dict_default(self):
assert SessionResetPolicy.from_dict({}).max_messages == 200
class TestSessionEntryMessageCount:
def test_default(self):
e = SessionEntry(session_key="k", session_id="s", created_at=datetime.now(), updated_at=datetime.now())
assert e.message_count == 0
def test_to_dict(self):
e = SessionEntry(session_key="k", session_id="s", created_at=datetime.now(), updated_at=datetime.now(), message_count=42)
assert e.to_dict()["message_count"] == 42
def test_from_dict(self):
e = SessionEntry.from_dict({"session_key": "k", "session_id": "s", "created_at": "2026-01-01T00:00:00", "updated_at": "2026-01-01T00:00:00", "message_count": 99})
assert e.message_count == 99
class TestShouldResetMessageLimit:
def test_at_limit(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 200
assert s._should_reset(e, src) == "message_limit"
def test_over_limit(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 250
assert s._should_reset(e, src) == "message_limit"
def test_below_limit(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 100
assert s._should_reset(e, src) is None
def test_unlimited(self):
s = _store(max_messages=0, mode="none")
src = _source()
e = s.get_or_create_session(src)
e.message_count = 9999
assert s._should_reset(e, src) is None
def test_custom_limit(self):
s = _store(max_messages=50)
src = _source()
e = s.get_or_create_session(src)
e.message_count = 50
assert s._should_reset(e, src) == "message_limit"
def test_just_under(self):
s = _store(max_messages=50)
src = _source()
e = s.get_or_create_session(src)
e.message_count = 49
assert s._should_reset(e, src) is None
class TestAppendIncrementsCount:
def test_user_message(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
s.append_to_transcript(e.session_id, {"role": "user", "content": "hi"})
e = s.get_or_create_session(src)
assert e.message_count == 1
def test_assistant_message(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
s.append_to_transcript(e.session_id, {"role": "user", "content": "hi"})
s.append_to_transcript(e.session_id, {"role": "assistant", "content": "hello"})
e = s.get_or_create_session(src)
assert e.message_count == 2
def test_meta_not_counted(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
s.append_to_transcript(e.session_id, {"role": "session_meta", "tools": []})
e = s.get_or_create_session(src)
assert e.message_count == 0
class TestGetMessageLimitInfo:
def test_at_limit(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 200
info = s.get_message_limit_info(e.session_key)
assert info["at_limit"] is True
assert info["near_limit"] is True
assert info["remaining"] == 0
def test_near_limit(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 180
info = s.get_message_limit_info(e.session_key)
assert info["near_limit"] is True
assert info["at_limit"] is False
assert info["remaining"] == 20
def test_well_below(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 50
info = s.get_message_limit_info(e.session_key)
assert info["near_limit"] is False
assert info["at_limit"] is False
def test_unknown(self):
s = _store()
info = s.get_message_limit_info("nonexistent")
assert info["at_limit"] is False
class TestResetMessageCount:
def test_reset(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 150
s.reset_message_count(e.session_key)
assert s.get_message_limit_info(e.session_key)["message_count"] == 0
class TestSessionRotation:
def test_fresh_count_after_reset(self):
s = _store()
src = _source()
e = s.get_or_create_session(src)
e.message_count = 200
new = s.reset_session(e.session_key)
assert new is not None
assert new.message_count == 0
assert new.session_id != e.session_id

View File

@@ -0,0 +1,107 @@
"""Tests for gateway token count persistence to SessionEntry and SessionDB.
Regression test for #316 — token tracking all zeros. The gateway must
propagate input_tokens / output_tokens from the agent result to both the
SessionEntry (sessions.json) and the SQLite session DB.
"""
import json
from datetime import datetime
from unittest.mock import MagicMock
import pytest
from gateway.session import SessionEntry
class TestUpdateSessionTokenFields:
"""Verify SessionEntry token fields are updated and serialized correctly."""
def test_session_entry_to_dict_includes_tokens(self):
entry = SessionEntry(
session_key="tg:123",
session_id="sid-1",
created_at=datetime.now(),
updated_at=datetime.now(),
input_tokens=1000,
output_tokens=500,
total_tokens=1500,
estimated_cost_usd=0.05,
)
d = entry.to_dict()
assert d["input_tokens"] == 1000
assert d["output_tokens"] == 500
assert d["total_tokens"] == 1500
assert d["estimated_cost_usd"] == 0.05
def test_session_entry_from_dict_restores_tokens(self):
now = datetime.now().isoformat()
data = {
"session_key": "tg:123",
"session_id": "sid-1",
"created_at": now,
"updated_at": now,
"input_tokens": 42,
"output_tokens": 21,
"total_tokens": 63,
"estimated_cost_usd": 0.001,
}
entry = SessionEntry.from_dict(data)
assert entry.input_tokens == 42
assert entry.output_tokens == 21
assert entry.total_tokens == 63
assert entry.estimated_cost_usd == 0.001
def test_session_entry_roundtrip_preserves_tokens(self):
"""to_dict -> from_dict must preserve all token fields."""
entry = SessionEntry(
session_key="cron:job7",
session_id="sid-7",
created_at=datetime.now(),
updated_at=datetime.now(),
input_tokens=9999,
output_tokens=1234,
total_tokens=11233,
cache_read_tokens=500,
cache_write_tokens=100,
estimated_cost_usd=0.42,
)
restored = SessionEntry.from_dict(entry.to_dict())
assert restored.input_tokens == 9999
assert restored.output_tokens == 1234
assert restored.total_tokens == 11233
assert restored.cache_read_tokens == 500
assert restored.cache_write_tokens == 100
assert restored.estimated_cost_usd == 0.42
class TestAgentResultTokenExtraction:
"""Verify the gateway extracts token counts from agent_result correctly."""
def test_agent_result_has_expected_keys(self):
"""Simulate what _run_agent returns and verify all token keys exist."""
result = {
"final_response": "hello",
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
"cache_read_tokens": 10,
"cache_write_tokens": 5,
"reasoning_tokens": 0,
"estimated_cost_usd": 0.002,
"last_prompt_tokens": 100,
"model": "test-model",
"session_id": "test-session-123",
}
# These are the extractions the gateway performs
assert result.get("input_tokens", 0) or 0 == 100
assert result.get("output_tokens", 0) or 0 == 50
assert result.get("total_tokens", 0) or 0 == 150
assert result.get("estimated_cost_usd") == 0.002
def test_agent_result_zero_fallback(self):
"""When token keys are missing, defaults to 0."""
result = {"final_response": "ok"}
assert result.get("input_tokens", 0) or 0 == 0
assert result.get("output_tokens", 0) or 0 == 0
assert result.get("total_tokens", 0) or 0 == 0