"""
gemma4_tool_normalizer.py — Normalize Gemma 4 tool call output quirks.

Gemma 4 (and some Ollama models) emit tool calls in formats that differ
from the OpenAI standard:

1. Extra whitespace around JSON arguments
2. Parallel tool calls split across separate assistant messages
3. Streaming chunks with split JSON

This module normalizes these into standard OpenAI tool_calls format.

Usage:
    from agent.gemma4_tool_normalizer import normalize_tool_calls, normalize_messages_tool_calls

    # Normalize a list of tool call dicts
    normalized = normalize_tool_calls(raw_tool_calls)

    # Normalize an entire conversation (merges split messages)
    messages = normalize_messages_tool_calls(messages)
"""

import hashlib
import json
import logging
import re
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

# A comma is "trailing" when only whitespace separates it from a closer.
_CLOSER_AHEAD = re.compile(r"\s*[}\]]")

# How far back from the end repair_json_fragment looks for a cut point.
_MAX_TRAILING_SCAN = 50


def _is_valid_json(text: str) -> bool:
    """Return True when *text* parses as JSON."""
    try:
        json.loads(text)
    except (ValueError, RecursionError):
        return False
    return True


def _strip_trailing_commas(text: str) -> str:
    """Drop commas that directly precede ``}`` or ``]`` outside of strings.

    Only double-quoted strings are recognised; commas inside them are kept,
    so string values such as ``"[1, 2, ]"`` survive untouched.
    """
    kept = []
    in_string = False
    escaped = False
    for idx, ch in enumerate(text):
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "," and _CLOSER_AHEAD.match(text, idx + 1):
            continue  # trailing comma: skip it
        kept.append(ch)
    return "".join(kept)


def _single_to_double_quotes(text: str) -> str:
    """Rewrite single-quoted strings as double-quoted JSON strings.

    Strings that are already double-quoted are copied as-is (apostrophes
    inside them are preserved). Inside a single-quoted string, bare ``"``
    characters are escaped and ``\\'`` is unescaped, because ``\\'`` is not
    a valid JSON escape sequence.
    """
    out = []
    quote = None  # delimiter of the string currently being scanned
    idx = 0
    length = len(text)
    while idx < length:
        ch = text[idx]
        if quote is None:
            if ch in ("'", '"'):
                quote = ch
                out.append('"')
            else:
                out.append(ch)
        elif ch == "\\" and idx + 1 < length:
            nxt = text[idx + 1]
            out.append("'" if nxt == "'" else ch + nxt)
            idx += 2
            continue
        elif ch == quote:
            quote = None
            out.append('"')
        elif ch == '"':
            out.append('\\"')
        else:
            out.append(ch)
        idx += 1
    return "".join(out)


def normalize_tool_call_args(args_str: str) -> str:
    """Normalize a tool call ``arguments`` string into valid JSON text.

    Handles Gemma 4 quirks:
    - Extra whitespace/newlines around JSON
    - Trailing commas
    - Single-quoted strings (converted to double-quoted)
    - Bare ``"key": value`` pairs without the enclosing braces

    Input that already parses as JSON (after stripping surrounding
    whitespace) is returned without any further rewriting, so string values
    that merely look like one of the quirks are never modified.

    Args:
        args_str: Raw arguments text. Falsy or non-string values are
            returned unchanged.

    Returns:
        The repaired JSON text, or the whitespace-stripped input when no
        repair yields valid JSON.
    """
    if not args_str or not isinstance(args_str, str):
        return args_str

    cleaned = args_str.strip()
    if not cleaned or _is_valid_json(cleaned):
        return cleaned

    requoted = _single_to_double_quotes(cleaned)
    candidates = [cleaned, requoted]
    if not cleaned.startswith(("{", "[")):
        # Bare key/value pairs: try again with the missing braces added.
        candidates += ["{" + cleaned + "}", "{" + requoted + "}"]

    for candidate in candidates:
        repaired = _strip_trailing_commas(candidate)
        if _is_valid_json(repaired):
            return repaired

    logger.debug("Could not normalize tool call arguments: %.200r", cleaned)
    return cleaned


def _stable_call_id(tc: dict) -> str:
    """Derive a ``call_NNNNNNNNNN`` id from the tool call's content.

    A SHA-256 digest is used instead of the built-in ``hash()`` because
    string hashes are salted per process, which would make the generated
    id differ between runs for the same tool call.
    """
    try:
        payload = json.dumps(tc, sort_keys=True, default=str)
    except (TypeError, ValueError):
        payload = repr(tc)
    digest = hashlib.sha256(payload.encode("utf-8", "backslashreplace")).hexdigest()
    return f"call_{int(digest, 16) % 10**10:010d}"


def normalize_tool_call(tc: dict) -> dict:
    """Normalize a single tool call dict.

    Returns a shallow copy whose ``function.arguments`` string has been
    passed through :func:`normalize_tool_call_args` and which always has a
    non-empty ``id``. A missing, ``None`` or empty id is replaced by one
    derived deterministically from the call's content. Non-dict input is
    returned unchanged.
    """
    if not isinstance(tc, dict):
        return tc

    normalized = dict(tc)

    fn = normalized.get("function")
    if isinstance(fn, dict):
        fn = dict(fn)
        args = fn.get("arguments")
        if isinstance(args, str):
            fn["arguments"] = normalize_tool_call_args(args)
        normalized["function"] = fn

    if not normalized.get("id"):
        normalized["id"] = _stable_call_id(normalized)

    return normalized


def normalize_tool_calls(tool_calls: List[dict]) -> List[dict]:
    """Normalize a list of tool calls.

    Non-dict entries are dropped. When several id-less calls would receive
    the same generated id (identical name and arguments), later ones get a
    numeric suffix so every call in the result can be answered separately.
    An empty or falsy input is returned as-is.
    """
    if not tool_calls:
        return tool_calls

    result = []
    seen_ids = set()
    for tc in tool_calls:
        if not isinstance(tc, dict):
            continue
        had_id = bool(tc.get("id"))
        normalized = normalize_tool_call(tc)
        call_id = normalized["id"]
        if not had_id:
            base, suffix = call_id, 1
            while call_id in seen_ids:
                call_id = f"{base}_{suffix}"
                suffix += 1
            normalized["id"] = call_id
        if isinstance(call_id, str):
            seen_ids.add(call_id)
        result.append(normalized)
    return result


def _carries_tool_calls(msg: Any) -> bool:
    """Return True for an assistant message dict with a non-empty tool_calls list."""
    if not isinstance(msg, dict) or msg.get("role") != "assistant":
        return False
    tool_calls = msg.get("tool_calls")
    return isinstance(tool_calls, list) and bool(tool_calls)


def _merge_group(group: List[dict]) -> dict:
    """Collapse consecutive tool-calling assistant messages into one message.

    Extra keys are taken from the first message of the group. String
    contents are joined with newlines; a single message with non-string
    content keeps that content untouched.
    """
    combined = dict(group[0])

    contents = [msg.get("content") for msg in group]
    if len(group) > 1 or not contents[0] or isinstance(contents[0], str):
        text_parts = [c for c in contents if isinstance(c, str) and c]
        if any(c and not isinstance(c, str) for c in contents):
            logger.warning("Dropping non-string content while merging split tool calls")
        combined["content"] = "\n".join(text_parts)

    raw_calls = [tc for msg in group for tc in msg["tool_calls"]]
    calls = normalize_tool_calls(raw_calls)
    if calls:
        combined["tool_calls"] = calls
    else:
        # Nothing usable was left; keep the message but not an empty list.
        combined.pop("tool_calls", None)
    return combined


def merge_split_tool_calls(messages: List[dict]) -> List[dict]:
    """Merge consecutive assistant messages with tool_calls into one.

    Gemma 4 sometimes emits parallel tool calls as separate assistant
    messages instead of one message with multiple tool_calls.

    Any other message (including non-dict entries) ends the current run,
    so the relative order of messages is always preserved. Tool calls in
    the merged message are normalized; input messages are not mutated.
    """
    if not messages:
        return messages

    merged = []
    group = []  # current run of assistant messages that carry tool calls

    for msg in messages:
        if _carries_tool_calls(msg):
            group.append(msg)
            continue
        if group:
            merged.append(_merge_group(group))
            group = []
        merged.append(msg)

    if group:
        merged.append(_merge_group(group))

    return merged


def normalize_messages_tool_calls(messages: List[dict]) -> List[dict]:
    """Full normalization pipeline for conversation messages.

    1. Merge split tool_call messages
    2. Normalize individual tool call arguments
    """
    return _normalize_tool_calls_in_messages(merge_split_tool_calls(messages))


def _normalize_tool_calls_in_messages(messages: List[dict]) -> List[dict]:
    """Normalize tool_calls within each message, copying changed messages."""
    result = []
    for msg in messages:
        if isinstance(msg, dict):
            msg = dict(msg)
            tool_calls = msg.get("tool_calls")
            if isinstance(tool_calls, list) and tool_calls:
                msg["tool_calls"] = normalize_tool_calls(tool_calls)
        result.append(msg)
    return result


def _close_open_structures(text: str) -> Optional[str]:
    """Append the closers needed to balance *text*, innermost first.

    Brackets inside double-quoted strings are ignored. Returns None when
    *text* ends inside a string or contains a mismatched closer, because
    no suffix can repair those cases.
    """
    closers = []
    in_string = False
    escaped = False
    for ch in text:
        if in_string:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
        elif ch == '"':
            in_string = True
        elif ch == "{":
            closers.append("}")
        elif ch == "[":
            closers.append("]")
        elif ch in "}]":
            if not closers or closers.pop() != ch:
                return None
    if in_string:
        return None
    return text + "".join(reversed(closers))


def repair_json_fragment(fragment: str, prefix: str = "") -> Optional[str]:
    """Attempt to repair a JSON fragment from streaming.

    Gemma 4 streaming may split JSON across chunks. This attempts to
    reassemble valid JSON from ``prefix + fragment``:

    1. Return it unchanged when it already parses.
    2. Otherwise close unclosed objects/arrays in correct nesting order.
    3. Otherwise drop an incomplete trailing key/value by cutting at a
       ``,`` or ``:`` within the last few dozen characters, then close.

    An incomplete trailing string value is dropped rather than closed, so a
    truncated value is never presented as if it were complete.

    Returns:
        The repaired JSON text, or None when *fragment* is empty or no
        repair produces valid JSON.
    """
    if not fragment:
        return None

    candidate = prefix + fragment
    if _is_valid_json(candidate):
        return candidate

    closed = _close_open_structures(candidate)
    if closed is not None and _is_valid_json(closed):
        return closed

    lower_bound = max(0, len(candidate) - _MAX_TRAILING_SCAN)
    for idx in range(len(candidate) - 1, lower_bound, -1):
        if candidate[idx] not in (",", ":"):
            continue
        trimmed = candidate[:idx].rstrip()
        if trimmed.endswith(","):
            trimmed = trimmed[:-1]
        # Cut points inside a string leave the prefix unterminated -> None.
        closed = _close_open_structures(trimmed)
        if closed is not None and _is_valid_json(closed):
            return closed

    return None


# ---------------------------------------------------------------------------
# Tests for Gemma 4 tool call normalizer
# (from tests/agent/test_gemma4_tool_normalizer.py; the functions under test
# are defined above, so no package import is needed here).
# ---------------------------------------------------------------------------


class TestNormalizeArgs:
    def test_strips_whitespace(self):
        result = normalize_tool_call_args(' \n {"path": "/tmp"} \n ')
        assert json.loads(result) == {"path": "/tmp"}

    def test_removes_trailing_comma(self):
        result = normalize_tool_call_args('{"path": "/tmp",}')
        assert json.loads(result) == {"path": "/tmp"}

    def test_fixes_single_quotes(self):
        result = normalize_tool_call_args("{'path': '/tmp'}")
        parsed = json.loads(result)
        assert parsed["path"] == "/tmp"

    def test_wraps_bare_kv_pairs(self):
        result = normalize_tool_call_args('"path": "/tmp", "mode": "read"')
        parsed = json.loads(result)
        assert parsed["path"] == "/tmp"

    def test_valid_json_unchanged(self):
        original = '{"command": "ls -la"}'
        result = normalize_tool_call_args(original)
        assert result == original

    def test_valid_json_with_comma_closer_in_string_unchanged(self):
        original = '{"code": "x = [1, 2, ]"}'
        assert normalize_tool_call_args(original) == original

    def test_empty_string(self):
        assert normalize_tool_call_args("") == ""

    def test_none_passthrough(self):
        assert normalize_tool_call_args(None) is None


class TestNormalizeToolCall:
    def test_normalizes_args(self):
        tc = {
            "id": "call_123",
            "function": {"name": "execute_code", "arguments": ' {"code": "print(1)"} '},
        }
        result = normalize_tool_call(tc)
        assert json.loads(result["function"]["arguments"]) == {"code": "print(1)"}

    def test_adds_missing_id(self):
        tc = {"function": {"name": "terminal", "arguments": '{"command":"ls"}'}}
        result = normalize_tool_call(tc)
        assert "id" in result
        assert result["id"].startswith("call_")

    def test_identical_calls_get_distinct_ids(self):
        tc = {"function": {"name": "terminal", "arguments": '{"command":"ls"}'}}
        ids = [call["id"] for call in normalize_tool_calls([tc, dict(tc)])]
        assert len(set(ids)) == 2


class TestMergeSplitToolCalls:
    def test_merges_consecutive_assistant_messages(self):
        messages = [
            {"role": "assistant", "content": "", "tool_calls": [{"id": "1", "function": {"name": "read_file", "arguments": '{"path":"a.py"}'}}]},
            {"role": "assistant", "content": "", "tool_calls": [{"id": "2", "function": {"name": "read_file", "arguments": '{"path":"b.py"}'}}]},
            {"role": "tool", "content": "file a content", "tool_call_id": "1"},
        ]
        result = merge_split_tool_calls(messages)
        # First message should have both tool calls merged
        assert len(result[0]["tool_calls"]) == 2
        assert len(result) == 2  # merged assistant + tool response

    def test_non_consecutive_not_merged(self):
        messages = [
            {"role": "assistant", "content": "", "tool_calls": [{"id": "1", "function": {"name": "x", "arguments": "{}"}}]},
            {"role": "tool", "content": "result", "tool_call_id": "1"},
            {"role": "assistant", "content": "", "tool_calls": [{"id": "2", "function": {"name": "y", "arguments": "{}"}}]},
        ]
        result = merge_split_tool_calls(messages)
        assert len(result) == 3  # no merging across tool response

    def test_non_dict_message_keeps_its_position(self):
        messages = [
            {"role": "assistant", "content": "", "tool_calls": [{"id": "1", "function": {"name": "x", "arguments": "{}"}}]},
            "raw",
        ]
        result = merge_split_tool_calls(messages)
        assert result[0]["role"] == "assistant"
        assert result[1] == "raw"


class TestRepairJson:
    def test_repair_unclosed_brace(self):
        result = repair_json_fragment('{"path": "/tmp"')
        assert result is not None
        assert json.loads(result) == {"path": "/tmp"}

    def test_repair_unclosed_array(self):
        result = repair_json_fragment('[1, 2, 3')
        assert result is not None
        assert json.loads(result) == [1, 2, 3]

    def test_repair_nested_array_in_object(self):
        result = repair_json_fragment('{"a": [1, 2')
        assert result is not None
        assert json.loads(result) == {"a": [1, 2]}

    def test_repair_trailing_key(self):
        result = repair_json_fragment('{"a": 1, "b"')
        assert result is not None
        assert json.loads(result) == {"a": 1}

    def test_valid_json_returned_unchanged(self):
        original = '{"x": 1}'
        assert repair_json_fragment(original) == original

    def test_empty_returns_none(self):
        assert repair_json_fragment("") is None