Merge PR #564: fix: count actual tool calls instead of tool-related messages

Authored by 0xbyt4. Fixes tool_call_count double-counting tool responses
and under-counting parallel tool calls.
This commit is contained in:
teknium1
2026-03-09 23:32:54 -07:00
2 changed files with 46 additions and 5 deletions

View File

@@ -490,12 +490,16 @@ class SessionDB:
msg_id = cursor.lastrowid
# Update counters
is_tool_related = role == "tool" or tool_calls is not None
if is_tool_related:
# Count actual tool calls from the tool_calls list (not from tool responses).
# A single assistant message can contain multiple parallel tool calls.
num_tool_calls = 0
if tool_calls is not None:
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
if num_tool_calls > 0:
self._conn.execute(
"""UPDATE sessions SET message_count = message_count + 1,
tool_call_count = tool_call_count + 1 WHERE id = ?""",
(session_id,),
tool_call_count = tool_call_count + ? WHERE id = ?""",
(num_tool_calls, session_id),
)
else:
self._conn.execute(

View File

@@ -94,13 +94,50 @@ class TestMessageStorage:
session = db.get_session("s1")
assert session["message_count"] == 2
def test_tool_message_increments_tool_count(self, db):
def test_tool_response_does_not_increment_tool_count(self, db):
"""Tool responses (role=tool) should not increment tool_call_count.
Only assistant messages with tool_calls should count.
"""
db.create_session(session_id="s1", source="cli")
db.append_message("s1", role="tool", content="result", tool_name="web_search")
session = db.get_session("s1")
assert session["tool_call_count"] == 0
def test_assistant_tool_calls_increment_by_count(self, db):
"""An assistant message with N tool_calls should increment by N."""
db.create_session(session_id="s1", source="cli")
tool_calls = [
{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}},
]
db.append_message("s1", role="assistant", content="", tool_calls=tool_calls)
session = db.get_session("s1")
assert session["tool_call_count"] == 1
def test_tool_call_count_matches_actual_calls(self, db):
"""tool_call_count should equal the number of tool calls made, not messages."""
db.create_session(session_id="s1", source="cli")
# Assistant makes 2 parallel tool calls in one message
tool_calls = [
{"id": "call_1", "function": {"name": "ha_call_service", "arguments": "{}"}},
{"id": "call_2", "function": {"name": "ha_call_service", "arguments": "{}"}},
]
db.append_message("s1", role="assistant", content="", tool_calls=tool_calls)
# Two tool responses come back
db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service")
db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service")
session = db.get_session("s1")
# Should be 2 (the actual number of tool calls), not 3
assert session["tool_call_count"] == 2, (
f"Expected 2 tool calls but got {session['tool_call_count']}. "
"tool responses are double-counted and multi-call messages are under-counted"
)
def test_tool_calls_serialization(self, db):
db.create_session(session_id="s1", source="cli")
tool_calls = [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}]