diff --git a/hermes_state.py b/hermes_state.py index 1d1f951c0..5864cbcff 100644 --- a/hermes_state.py +++ b/hermes_state.py @@ -259,12 +259,16 @@ class SessionDB: msg_id = cursor.lastrowid # Update counters - is_tool_related = role == "tool" or tool_calls is not None - if is_tool_related: + # Count actual tool calls from the tool_calls list (not from tool responses). + # A single assistant message can contain multiple parallel tool calls. + num_tool_calls = 0 + if tool_calls is not None: + num_tool_calls = len(tool_calls) if isinstance(tool_calls, list) else 1 + if num_tool_calls > 0: self._conn.execute( """UPDATE sessions SET message_count = message_count + 1, - tool_call_count = tool_call_count + 1 WHERE id = ?""", - (session_id,), + tool_call_count = tool_call_count + ? WHERE id = ?""", + (num_tool_calls, session_id), ) else: self._conn.execute( diff --git a/tests/test_hermes_state.py b/tests/test_hermes_state.py index 734db494f..de2e05e52 100644 --- a/tests/test_hermes_state.py +++ b/tests/test_hermes_state.py @@ -94,13 +94,50 @@ class TestMessageStorage: session = db.get_session("s1") assert session["message_count"] == 2 - def test_tool_message_increments_tool_count(self, db): + def test_tool_response_does_not_increment_tool_count(self, db): + """Tool responses (role=tool) should not increment tool_call_count. + + Only assistant messages with tool_calls should count. + """ db.create_session(session_id="s1", source="cli") db.append_message("s1", role="tool", content="result", tool_name="web_search") + session = db.get_session("s1") + assert session["tool_call_count"] == 0 + + def test_assistant_tool_calls_increment_by_count(self, db): + """An assistant message with N tool_calls should increment by N.""" + db.create_session(session_id="s1", source="cli") + tool_calls = [ + {"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}, + ] + db.append_message("s1", role="assistant", content="", tool_calls=tool_calls) + session = db.get_session("s1") assert session["tool_call_count"] == 1 + def test_tool_call_count_matches_actual_calls(self, db): + """tool_call_count should equal the number of tool calls made, not messages.""" + db.create_session(session_id="s1", source="cli") + + # Assistant makes 2 parallel tool calls in one message + tool_calls = [ + {"id": "call_1", "function": {"name": "ha_call_service", "arguments": "{}"}}, + {"id": "call_2", "function": {"name": "ha_call_service", "arguments": "{}"}}, + ] + db.append_message("s1", role="assistant", content="", tool_calls=tool_calls) + + # Two tool responses come back + db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service") + db.append_message("s1", role="tool", content="ok", tool_name="ha_call_service") + + session = db.get_session("s1") + # Should be 2 (the actual number of tool calls), not 3 + assert session["tool_call_count"] == 2, ( + f"Expected 2 tool calls but got {session['tool_call_count']}. " + "tool responses are double-counted and multi-call messages are under-counted" + ) + def test_tool_calls_serialization(self, db): db.create_session(session_id="s1", source="cli") tool_calls = [{"id": "call_1", "function": {"name": "web_search", "arguments": "{}"}}]