Normalize Gemma 4 tool call quirks:
- Extra whitespace around JSON arguments
- Parallel tool calls split across messages
- Single-quoted strings, trailing commas
- Unclosed JSON from streaming chunks

agent/gemma4_tool_normalizer.py (234 lines):
- normalize_tool_call_args(): strip whitespace, fix quotes, trailing commas
- merge_split_tool_calls(): combine split assistant messages
- repair_json_fragment(): reassemble split streaming JSON
- normalize_messages_tool_calls(): full pipeline

16 tests, all passing.

Closes #797
107 lines
3.9 KiB
Python
107 lines
3.9 KiB
Python
"""Tests for Gemma 4 tool call normalizer."""
|
|
|
|
import json
|
|
import pytest
|
|
|
|
from agent.gemma4_tool_normalizer import (
|
|
normalize_tool_call_args,
|
|
normalize_tool_call,
|
|
normalize_tool_calls,
|
|
merge_split_tool_calls,
|
|
normalize_messages_tool_calls,
|
|
repair_json_fragment,
|
|
)
|
|
|
|
|
|
class TestNormalizeArgs:
    """Checks for ``normalize_tool_call_args`` on raw argument strings."""

    def test_strips_whitespace(self):
        """A whitespace-padded object normalizes to JSON equal to the bare object."""
        padded = ' \n {"path": "/tmp"} \n '
        normalized = normalize_tool_call_args(padded)
        decoded = json.loads(normalized)
        assert decoded == {"path": "/tmp"}

    def test_removes_trailing_comma(self):
        """A trailing comma before the closing brace must not survive normalization."""
        normalized = normalize_tool_call_args('{"path": "/tmp",}')
        decoded = json.loads(normalized)
        assert decoded == {"path": "/tmp"}

    def test_fixes_single_quotes(self):
        """Single-quoted input normalizes to parseable JSON with the same ``path``."""
        normalized = normalize_tool_call_args("{'path': '/tmp'}")
        decoded = json.loads(normalized)
        assert decoded["path"] == "/tmp"

    def test_wraps_bare_kv_pairs(self):
        """Key/value pairs without enclosing braces still yield a parseable object."""
        bare_pairs = '"path": "/tmp", "mode": "read"'
        decoded = json.loads(normalize_tool_call_args(bare_pairs))
        assert decoded["path"] == "/tmp"

    def test_valid_json_unchanged(self):
        """Already-valid JSON is returned as the exact same string."""
        well_formed = '{"command": "ls -la"}'
        assert normalize_tool_call_args(well_formed) == well_formed

    def test_empty_string(self):
        """The empty string maps to the empty string."""
        outcome = normalize_tool_call_args("")
        assert outcome == ""

    def test_none_passthrough(self):
        """``None`` is handed back as ``None``."""
        outcome = normalize_tool_call_args(None)
        assert outcome is None
|
|
|
|
|
|
class TestNormalizeToolCall:
    """Checks for ``normalize_tool_call`` on a single tool-call dict."""

    def test_normalizes_args(self):
        """Padded ``arguments`` inside the call come back as clean JSON."""
        function_payload = {
            "name": "execute_code",
            "arguments": ' {"code": "print(1)"} ',
        }
        call = {"id": "call_123", "function": function_payload}

        normalized = normalize_tool_call(call)

        decoded_args = json.loads(normalized["function"]["arguments"])
        assert decoded_args == {"code": "print(1)"}

    def test_adds_missing_id(self):
        """A call submitted without ``id`` gains one prefixed with ``call_``."""
        call = {
            "function": {"name": "terminal", "arguments": '{"command":"ls"}'},
        }

        normalized = normalize_tool_call(call)

        assert "id" in normalized
        assert normalized["id"].startswith("call_")
|
|
|
|
|
|
class TestMergeSplitToolCalls:
    """Checks for ``merge_split_tool_calls`` on message histories."""

    @staticmethod
    def _assistant_call(call_id, name, arguments):
        """Build an assistant message with empty content and exactly one tool call."""
        call = {"id": call_id, "function": {"name": name, "arguments": arguments}}
        return {"role": "assistant", "content": "", "tool_calls": [call]}

    def test_merges_consecutive_assistant_messages(self):
        """Two back-to-back assistant messages collapse into one holding both calls."""
        history = [
            self._assistant_call("1", "read_file", '{"path":"a.py"}'),
            self._assistant_call("2", "read_file", '{"path":"b.py"}'),
            {"role": "tool", "content": "file a content", "tool_call_id": "1"},
        ]

        merged = merge_split_tool_calls(history)

        # The leading assistant message now carries both tool calls ...
        leading_calls = merged[0]["tool_calls"]
        assert len(leading_calls) == 2
        # ... leaving one assistant message plus the tool response.
        assert len(merged) == 2

    def test_non_consecutive_not_merged(self):
        """Assistant messages separated by a tool response stay separate."""
        history = [
            self._assistant_call("1", "x", "{}"),
            {"role": "tool", "content": "result", "tool_call_id": "1"},
            self._assistant_call("2", "y", "{}"),
        ]

        merged = merge_split_tool_calls(history)

        # The tool response sits between the two assistant turns, so nothing merges.
        assert len(merged) == 3
|
|
|
|
|
|
class TestRepairJson:
    """Checks for ``repair_json_fragment`` on truncated JSON text."""

    def test_repair_unclosed_brace(self):
        """An object missing its closing brace is repaired to the full object."""
        repaired = repair_json_fragment('{"path": "/tmp"')
        assert repaired is not None
        decoded = json.loads(repaired)
        assert decoded == {"path": "/tmp"}

    def test_repair_unclosed_array(self):
        """An array missing its closing bracket is repaired to the full array."""
        repaired = repair_json_fragment('[1, 2, 3')
        assert repaired is not None
        decoded = json.loads(repaired)
        assert decoded == [1, 2, 3]

    def test_repair_trailing_key(self):
        """A dangling key with no value is dropped; the complete pair is kept."""
        repaired = repair_json_fragment('{"a": 1, "b"')
        assert repaired is not None
        decoded = json.loads(repaired)
        assert decoded == {"a": 1}

    def test_valid_json_returned_unchanged(self):
        """Well-formed JSON comes back as the identical string."""
        well_formed = '{"x": 1}'
        outcome = repair_json_fragment(well_formed)
        assert outcome == well_formed

    def test_empty_returns_none(self):
        """The empty string yields ``None``."""
        outcome = repair_json_fragment("")
        assert outcome is None
|