Merge PR #564: fix: count actual tool calls instead of tool-related messages
Authored by 0xbyt4. Fixes tool_call_count double-counting tool responses and under-counting parallel tool calls.
This commit is contained in:
@@ -490,12 +490,16 @@ class SessionDB:
|
||||
msg_id = cursor.lastrowid
|
||||
|
||||
# Update counters
|
||||
is_tool_related = role == "tool" or tool_calls is not None
|
||||
if is_tool_related:
|
||||
# Count actual tool calls from the tool_calls list (not from tool responses).
|
||||
# A single assistant message can contain multiple parallel tool calls.
|
||||
num_tool_calls = 0
|
||||
if tool_calls is not None:
|
||||
num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1
|
||||
if num_tool_calls > 0:
|
||||
self._conn.execute(
|
||||
"""UPDATE sessions SET message_count = message_count + 1,
|
||||
tool_call_count = tool_call_count + 1 WHERE id = ?""",
|
||||
(session_id,),
|
||||
tool_call_count = tool_call_count + ? WHERE id = ?""",
|
||||
(num_tool_calls, session_id),
|
||||
)
|
||||
else:
|
||||
self._conn.execute(
|
||||
|
||||
@@ -94,13 +94,50 @@ class TestMessageStorage:
|
||||
session = db.get_session("s1")
|
||||
assert session["message_count"] == 2
|
||||
|
||||
def test_tool_message_increments_tool_count(self, db):
|
||||
def test_tool_response_does_not_increment_tool_count(self, db):
|
||||
"""Tool responses (role=tool) should not increment tool_call_count.
|
||||
|
||||
Only assistant messages with tool_calls should count.
|
||||
"""
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
db.append_message("s1", role="tool", content="result", tool_name="web_search")
|
||||
|
||||
session = db.get_session("s1")
|
||||
assert session["tool_call_count"] == 0
|
||||
|
||||
def test_assistant_tool_calls_increment_by_count(self, db):
|
||||
"""An assistant message with N tool_calls should increment by N."""
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
tool_calls = [
|
||||
{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}},
|
||||
]
|
||||
db.append_message("s1", role="assistant", content="", tool_calls=tool_calls)
|
||||
|
||||
session = db.get_session("s1")
|
||||
assert session["tool_call_count"] == 1
|
||||
|
||||
def test_tool_call_count_matches_actual_calls(self, db):
|
||||
"""tool_call_count should equal the number of tool calls made, not messages."""
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
|
||||
# Assistant makes 2 parallel tool calls in one message
|
||||
tool_calls = [
|
||||
{"id": "call_1", "function": {"name": "ha_call_service", "arguments": "{}"}},
|
||||
{"id": "call_2", "function": {"name": "ha_call_service", "arguments": "{}"}},
|
||||
]
|
||||
db.append_message("s1", role="assistant", content="", tool_calls=tool_calls)
|
||||
|
||||
# Two tool responses come back
|
||||
db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service")
|
||||
db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service")
|
||||
|
||||
session = db.get_session("s1")
|
||||
# Should be 2 (the actual number of tool calls), not 3
|
||||
assert session["tool_call_count"] == 2, (
|
||||
f"Expected 2 tool calls but got {session['tool_call_count']}. "
|
||||
"tool responses are double-counted and multi-call messages are under-counted"
|
||||
)
|
||||
|
||||
def test_tool_calls_serialization(self, db):
|
||||
db.create_session(session_id="s1", source="cli")
|
||||
tool_calls = [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}]
|
||||
|
||||
Reference in New Issue
Block a user