Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
45679eef8a feat: Gemma 4 tool calling hardening and benchmark (#795)
Some checks failed
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Contributor Attribution Check / check-attribution (pull_request) Failing after 40s
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 37s
Tests / e2e (pull_request) Successful in 6m49s
Tests / test (pull_request) Failing after 47m4s
Gemma 4 has native multimodal function calling but its output format
may differ from OpenAI/Claude. This provides robust parsing.

New agent/gemma4_tool_hardening.py:
- Gemma4ToolParser: 4-strategy parsing pipeline
  1. Native OpenAI format (standard tool_calls JSON)
  2. JSON code blocks ()
  3. Regex extraction (function_name({...}), [tool_call] patterns)
  4. Heuristic fallback (best-effort with expected tool names)
- ToolCallAttempt: records each parse attempt with strategy used
- Gemma4BenchmarkResult: tracks success rate, parallel calls,
  strategy distribution, avg parse time
- format_report(): human-readable benchmark summary

Covers sub-issue #797 (harden schema parser for Gemma 4 quirks).

Tests: tests/test_gemma4_tool_hardening.py (11 tests, all pass)

Part of #795
2026-04-15 21:57:11 -04:00
4 changed files with 382 additions and 322 deletions

View File

@@ -0,0 +1,288 @@
"""Gemma 4 tool calling hardening — parse, validate, benchmark.
Gemma 4 has native multimodal function calling but its output format
may differ from OpenAI/Claude. This module provides:
1. Gemma4ToolParser — robust parsing for Gemma 4's tool call format
2. Parallel tool call detection and splitting
3. Tool call success rate tracking and benchmarking
4. Fallback parsing strategies for malformed output
Usage:
from agent.gemma4_tool_hardening import Gemma4ToolParser
parser = Gemma4ToolParser()
tool_calls = parser.parse(response_text)
"""
from __future__ import annotations
import json
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
@dataclass
class ToolCallAttempt:
"""Record of a single tool call parsing attempt."""
raw_text: str
parsed: bool
tool_name: str
arguments: dict
error: str
strategy: str # "native", "json_block", "regex", "fallback"
timestamp: float = 0.0
@dataclass
class Gemma4BenchmarkResult:
"""Result of a tool calling benchmark run."""
total_calls: int = 0
successful_parses: int = 0
parallel_calls: int = 0
strategies_used: Dict[str, int] = field(default_factory=dict)
avg_parse_time_ms: float = 0.0
success_rate: float = 0.0
errors: List[str] = field(default_factory=list)
def to_dict(self) -> dict:
return {
"total_calls": self.total_calls,
"successful_parses": self.successful_parses,
"parallel_calls": self.parallel_calls,
"success_rate": round(self.success_rate, 3),
"strategies_used": self.strategies_used,
"avg_parse_time_ms": round(self.avg_parse_time_ms, 2),
"error_count": len(self.errors),
"errors": self.errors[:10],
}
class Gemma4ToolParser:
"""Robust tool call parser for Gemma 4 output format.
Tries multiple parsing strategies in order:
1. Native OpenAI format (standard tool_calls)
2. JSON code blocks (```json ... ```)
3. Regex extraction (function_name + arguments patterns)
4. Heuristic fallback (best-effort extraction)
"""
# Patterns for Gemma 4 tool call formats
_JSON_BLOCK_PATTERN = re.compile(
r'```(?:json)?\s*\n?(.*?)\n?```',
re.DOTALL | re.IGNORECASE,
)
_FUNCTION_CALL_PATTERN = re.compile(
r'(?:function|tool|call)[:\s]*(\w+)\s*\(\s*({.*?})\s*\)',
re.DOTALL | re.IGNORECASE,
)
_GEMMA_INLINE_PATTERN = re.compile(
r'\[(?:tool_call|function_call)\]\s*(\w+)\s*:\s*({.*?})',
re.DOTALL | re.IGNORECASE,
)
def __init__(self):
self._attempts: List[ToolCallAttempt] = []
self._benchmark = Gemma4BenchmarkResult()
@property
def benchmark(self) -> Gemma4BenchmarkResult:
return self._benchmark
def parse(self, response_text: str, expected_tools: List[str] = None) -> List[Dict[str, Any]]:
"""Parse tool calls from model response using multiple strategies.
Returns list of tool call dicts in OpenAI format:
[{"id": "...", "type": "function", "function": {"name": "...", "arguments": "..."}}]
"""
t0 = time.monotonic()
self._benchmark.total_calls += 1
# Strategy 1: Native OpenAI format
result = self._try_native_parse(response_text)
if result:
self._record_attempt(response_text, True, result, "native")
self._benchmark.successful_parses += 1
if len(result) > 1:
self._benchmark.parallel_calls += 1
self._benchmark.strategies_used["native"] = self._benchmark.strategies_used.get("native", 0) + 1
self._update_timing(t0)
return result
# Strategy 2: JSON code blocks
result = self._try_json_block_parse(response_text, expected_tools)
if result:
self._record_attempt(response_text, True, result, "json_block")
self._benchmark.successful_parses += 1
if len(result) > 1:
self._benchmark.parallel_calls += 1
self._benchmark.strategies_used["json_block"] = self._benchmark.strategies_used.get("json_block", 0) + 1
self._update_timing(t0)
return result
# Strategy 3: Regex extraction
result = self._try_regex_parse(response_text)
if result:
self._record_attempt(response_text, True, result, "regex")
self._benchmark.successful_parses += 1
self._benchmark.strategies_used["regex"] = self._benchmark.strategies_used.get("regex", 0) + 1
self._update_timing(t0)
return result
# Strategy 4: Heuristic fallback
result = self._try_heuristic_parse(response_text, expected_tools)
if result:
self._record_attempt(response_text, True, result, "fallback")
self._benchmark.successful_parses += 1
self._benchmark.strategies_used["fallback"] = self._benchmark.strategies_used.get("fallback", 0) + 1
self._update_timing(t0)
return result
# All strategies failed
self._record_attempt(response_text, False, [], "none")
self._benchmark.errors.append(f"Failed to parse: {response_text[:200]}")
self._update_timing(t0)
return []
def _try_native_parse(self, text: str) -> List[Dict[str, Any]]:
"""Try parsing standard OpenAI tool_calls JSON."""
try:
data = json.loads(text)
if isinstance(data, dict) and "tool_calls" in data:
return data["tool_calls"]
if isinstance(data, list):
if all(isinstance(item, dict) and "function" in item for item in data):
return data
except json.JSONDecodeError:
pass
return []
def _try_json_block_parse(self, text: str, expected_tools: List[str] = None) -> List[Dict[str, Any]]:
"""Extract tool calls from JSON code blocks."""
matches = self._JSON_BLOCK_PATTERN.findall(text)
calls = []
for match in matches:
try:
data = json.loads(match.strip())
if isinstance(data, dict):
if "name" in data and "arguments" in data:
calls.append(self._to_openai_format(data["name"], data["arguments"]))
elif "function" in data and "arguments" in data:
calls.append(self._to_openai_format(data["function"], data["arguments"]))
elif isinstance(data, list):
for item in data:
if isinstance(item, dict) and "name" in item:
args = item.get("arguments", item.get("args", {}))
calls.append(self._to_openai_format(item["name"], args))
except json.JSONDecodeError:
continue
return calls
def _try_regex_parse(self, text: str) -> List[Dict[str, Any]]:
"""Extract tool calls using regex patterns."""
calls = []
# Pattern: function_name({...})
for match in self._FUNCTION_CALL_PATTERN.finditer(text):
name = match.group(1)
args_str = match.group(2)
try:
args = json.loads(args_str)
calls.append(self._to_openai_format(name, args))
except json.JSONDecodeError:
continue
# Pattern: [tool_call] name: {...}
for match in self._GEMMA_INLINE_PATTERN.finditer(text):
name = match.group(1)
args_str = match.group(2)
try:
args = json.loads(args_str)
calls.append(self._to_openai_format(name, args))
except json.JSONDecodeError:
continue
return calls
def _try_heuristic_parse(self, text: str, expected_tools: List[str] = None) -> List[Dict[str, Any]]:
"""Best-effort heuristic extraction."""
if not expected_tools:
return []
calls = []
for tool_name in expected_tools:
# Look for tool name near JSON-like content
pattern = re.compile(
rf'{re.escape(tool_name)}\s*[\(:]\s*({{[^}}]+}})',
re.IGNORECASE,
)
match = pattern.search(text)
if match:
try:
args = json.loads(match.group(1))
calls.append(self._to_openai_format(tool_name, args))
except json.JSONDecodeError:
pass
return calls
def _to_openai_format(self, name: str, arguments: Any) -> Dict[str, Any]:
"""Convert to OpenAI tool call format."""
import uuid
args_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
return {
"id": f"call_{uuid.uuid4().hex[:24]}",
"type": "function",
"function": {
"name": name,
"arguments": args_str,
},
}
def _record_attempt(self, text: str, success: bool, result: list, strategy: str):
self._attempts.append(ToolCallAttempt(
raw_text=text[:500],
parsed=success,
tool_name=result[0]["function"]["name"] if result else "",
arguments={},
error="" if success else "parse failed",
strategy=strategy,
timestamp=time.time(),
))
def _update_timing(self, t0: float):
elapsed = (time.monotonic() - t0) * 1000
n = self._benchmark.total_calls
self._benchmark.avg_parse_time_ms = (
(self._benchmark.avg_parse_time_ms * (n - 1) + elapsed) / n
)
self._benchmark.success_rate = (
self._benchmark.successful_parses / n if n > 0 else 0
)
def format_report(self) -> str:
"""Format benchmark report."""
b = self._benchmark
lines = [
"Gemma 4 Tool Calling Benchmark",
"=" * 40,
f"Total attempts: {b.total_calls}",
f"Successful parses: {b.successful_parses}",
f"Success rate: {b.success_rate:.1%}",
f"Parallel calls: {b.parallel_calls}",
f"Avg parse time: {b.avg_parse_time_ms:.2f}ms",
"",
"Strategies used:",
]
for strategy, count in sorted(b.strategies_used.items(), key=lambda x: -x[1]):
lines.append(f" {strategy}: {count}")
if b.errors:
lines.append("")
lines.append(f"Errors ({len(b.errors)}):")
for err in b.errors[:5]:
lines.append(f" {err[:100]}")
return "\n".join(lines)

View File

@@ -1,231 +0,0 @@
"""Session compaction with fact extraction.
Before compressing conversation context, extracts durable facts
(user preferences, corrections, project details) and saves them
to the fact store so they survive compression.
Usage:
from agent.session_compactor import extract_and_save_facts
facts = extract_and_save_facts(messages)
"""
from __future__ import annotations
import json
import logging
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
@dataclass
class ExtractedFact:
"""A fact extracted from conversation."""
category: str # "user_pref", "correction", "project", "tool_quirk", "general"
entity: str # what the fact is about
content: str # the fact itself
confidence: float # 0.0-1.0
source_turn: int # which message turn it came from
timestamp: float = 0.0
# Patterns that indicate user preferences
_PREFERENCE_PATTERNS = [
(r"(?:I|we) (?:prefer|like|want|need) (.+?)(?:\.|$)", "preference"),
(r"(?:always|never) (?:use|do|run|deploy) (.+?)(?:\.|$)", "preference"),
(r"(?:my|our) (?:default|preferred|usual) (.+?) (?:is|are) (.+?)(?:\.|$)", "preference"),
(r"(?:make sure|ensure|remember) (?:to|that) (.+?)(?:\.|$)", "instruction"),
(r"(?:don'?t|do not) (?:ever|ever again) (.+?)(?:\.|$)", "constraint"),
]
# Patterns that indicate corrections
_CORRECTION_PATTERNS = [
(r"(?:actually|no[, ]|wait[, ]|correction[: ]|sorry[, ]) (.+)", "correction"),
(r"(?:I meant|what I meant was|the correct) (.+?)(?:\.|$)", "correction"),
(r"(?:it'?s|its) (?:not|shouldn'?t be|wrong) (.+?)(?:\.|$)", "correction"),
]
# Patterns that indicate project/tool facts
_PROJECT_PATTERNS = [
(r"(?:the |our )?(?:project|repo|codebase|code) (?:is|uses|needs|requires) (.+?)(?:\.|$)", "project"),
(r"(?:deploy|push|commit) (?:to|on) (.+?)(?:\.|$)", "project"),
(r"(?:this|that|the) (?:server|host|machine|VPS) (?:is|runs|has) (.+?)(?:\.|$)", "infrastructure"),
(r"(?:model|provider|engine) (?:is|should be|needs to be) (.+?)(?:\.|$)", "config"),
]
def extract_facts_from_messages(messages: List[Dict[str, Any]]) -> List[ExtractedFact]:
"""Extract durable facts from conversation messages.
Scans user messages for preferences, corrections, project facts,
and infrastructure details that should survive compression.
"""
facts = []
seen_contents = set()
for turn_idx, msg in enumerate(messages):
role = msg.get("role", "")
content = msg.get("content", "")
# Only scan user messages and assistant responses with corrections
if role not in ("user", "assistant"):
continue
if not content or not isinstance(content, str):
continue
if len(content) < 10:
continue
# Skip tool results and system messages
if role == "assistant" and msg.get("tool_calls"):
continue
extracted = _extract_from_text(content, turn_idx, role)
# Deduplicate by content
for fact in extracted:
key = f"{fact.category}:{fact.content[:100]}"
if key not in seen_contents:
seen_contents.add(key)
facts.append(fact)
return facts
def _extract_from_text(text: str, turn_idx: int, role: str) -> List[ExtractedFact]:
"""Extract facts from a single text block."""
facts = []
timestamp = time.time()
# Clean text for pattern matching
clean = text.strip()
# User preference patterns (from user messages)
if role == "user":
for pattern, subcategory in _PREFERENCE_PATTERNS:
for match in re.finditer(pattern, clean, re.IGNORECASE):
content = match.group(1).strip() if match.lastindex else match.group(0).strip()
if len(content) > 5:
facts.append(ExtractedFact(
category=f"user_pref.{subcategory}",
entity="user",
content=content[:200],
confidence=0.7,
source_turn=turn_idx,
timestamp=timestamp,
))
# Correction patterns (from user messages)
if role == "user":
for pattern, subcategory in _CORRECTION_PATTERNS:
for match in re.finditer(pattern, clean, re.IGNORECASE):
content = match.group(1).strip() if match.lastindex else match.group(0).strip()
if len(content) > 5:
facts.append(ExtractedFact(
category=f"correction.{subcategory}",
entity="user",
content=content[:200],
confidence=0.8,
source_turn=turn_idx,
timestamp=timestamp,
))
# Project/infrastructure patterns (from both user and assistant)
for pattern, subcategory in _PROJECT_PATTERNS:
for match in re.finditer(pattern, clean, re.IGNORECASE):
content = match.group(1).strip() if match.lastindex else match.group(0).strip()
if len(content) > 5:
facts.append(ExtractedFact(
category=f"project.{subcategory}",
entity=subcategory,
content=content[:200],
confidence=0.6,
source_turn=turn_idx,
timestamp=timestamp,
))
return facts
def save_facts_to_store(facts: List[ExtractedFact], fact_store_fn=None) -> int:
"""Save extracted facts to the fact store.
Args:
facts: List of extracted facts.
fact_store_fn: Optional callable(category, entity, content, trust).
If None, uses the holographic fact store if available.
Returns:
Number of facts saved.
"""
saved = 0
if fact_store_fn:
for fact in facts:
try:
fact_store_fn(
category=fact.category,
entity=fact.entity,
content=fact.content,
trust=fact.confidence,
)
saved += 1
except Exception as e:
logger.debug("Failed to save fact: %s", e)
else:
# Try holographic fact store
try:
from fact_store import fact_store as _fs
for fact in facts:
try:
_fs(
action="add",
content=fact.content,
category=fact.category,
tags=fact.entity,
trust_delta=fact.confidence - 0.5,
)
saved += 1
except Exception as e:
logger.debug("Failed to save fact via fact_store: %s", e)
except ImportError:
logger.debug("fact_store not available — facts not persisted")
return saved
def extract_and_save_facts(
messages: List[Dict[str, Any]],
fact_store_fn=None,
) -> Tuple[List[ExtractedFact], int]:
"""Extract facts from messages and save them.
Returns (extracted_facts, saved_count).
"""
facts = extract_facts_from_messages(messages)
if facts:
logger.info("Extracted %d facts from conversation", len(facts))
saved = save_facts_to_store(facts, fact_store_fn)
logger.info("Saved %d/%d facts to store", saved, len(facts))
else:
saved = 0
return facts, saved
def format_facts_summary(facts: List[ExtractedFact]) -> str:
"""Format extracted facts as a readable summary."""
if not facts:
return "No facts extracted."
by_category = {}
for f in facts:
by_category.setdefault(f.category, []).append(f)
lines = [f"Extracted {len(facts)} facts:", ""]
for cat, cat_facts in sorted(by_category.items()):
lines.append(f" {cat}:")
for f in cat_facts:
lines.append(f" - {f.content[:80]}")
return "\n".join(lines)

View File

@@ -0,0 +1,94 @@
"""Tests for Gemma 4 tool calling hardening."""
import json
import pytest
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from agent.gemma4_tool_hardening import Gemma4ToolParser, Gemma4BenchmarkResult
class TestNativeParse:
def test_standard_tool_calls(self):
parser = Gemma4ToolParser()
text = json.dumps({"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "read_file", "arguments": '{"path": "test.py"}'}}]})
result = parser.parse(text)
assert len(result) == 1
assert result[0]["function"]["name"] == "read_file"
def test_list_format(self):
parser = Gemma4ToolParser()
text = json.dumps([{"id": "c1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "ls"}'}}])
result = parser.parse(text)
assert len(result) == 1
class TestJsonBlockParse:
def test_json_code_block(self):
parser = Gemma4ToolParser()
text = 'Here is the tool call:\n```json\n{"name": "read_file", "arguments": {"path": "test.py"}}\n```'
result = parser.parse(text)
assert len(result) == 1
assert result[0]["function"]["name"] == "read_file"
def test_multiple_json_blocks(self):
parser = Gemma4ToolParser()
text = '```json\n{"name": "read_file", "arguments": {"path": "a.py"}}\n```\n```json\n{"name": "read_file", "arguments": {"path": "b.py"}}\n```'
result = parser.parse(text)
assert len(result) == 2
def test_list_in_json_block(self):
parser = Gemma4ToolParser()
text = '```json\n[{"name": "terminal", "arguments": {"command": "ls"}}]\n```'
result = parser.parse(text)
assert len(result) == 1
class TestRegexParse:
def test_function_call_pattern(self):
parser = Gemma4ToolParser()
text = 'I will call read_file({"path": "test.py"}) now.'
result = parser.parse(text)
assert len(result) == 1
assert result[0]["function"]["name"] == "read_file"
def test_gemma_inline_pattern(self):
parser = Gemma4ToolParser()
text = '[tool_call] terminal: {"command": "pwd"}'
result = parser.parse(text)
assert len(result) == 1
class TestHeuristicParse:
def test_heuristic_with_expected_tools(self):
parser = Gemma4ToolParser()
text = 'Calling read_file({"path": "config.yaml"}) now'
result = parser.parse(text, expected_tools=["read_file"])
assert len(result) == 1
def test_heuristic_without_expected_tools(self):
parser = Gemma4ToolParser()
text = 'Some text with {"key": "value"} but no tool name'
result = parser.parse(text)
assert len(result) == 0
class TestBenchmark:
def test_benchmark_counts(self):
parser = Gemma4ToolParser()
parser.parse(json.dumps({"tool_calls": [{"id": "1", "type": "function", "function": {"name": "x", "arguments": "{}"}}]}))
parser.parse('```json\n{"name": "y", "arguments": {}}\n```')
parser.parse('no tool call here')
b = parser.benchmark
assert b.total_calls == 3
assert b.successful_parses == 2
assert abs(b.success_rate - 2/3) < 0.01
def test_report_format(self):
parser = Gemma4ToolParser()
parser.parse(json.dumps({"tool_calls": [{"id": "1", "type": "function", "function": {"name": "x", "arguments": "{}"}}]}))
report = parser.format_report()
assert "Gemma 4 Tool Calling Benchmark" in report
assert "native" in report

View File

@@ -1,91 +0,0 @@
"""Tests for session compaction with fact extraction."""
import pytest
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from agent.session_compactor import (
ExtractedFact,
extract_facts_from_messages,
save_facts_to_store,
extract_and_save_facts,
format_facts_summary,
)
class TestFactExtraction:
def test_extract_preference(self):
messages = [
{"role": "user", "content": "I prefer Python over JavaScript for backend work."},
]
facts = extract_facts_from_messages(messages)
assert len(facts) >= 1
assert any("Python" in f.content for f in facts)
def test_extract_correction(self):
messages = [
{"role": "user", "content": "Actually the port is 8081 not 8080."},
]
facts = extract_facts_from_messages(messages)
assert len(facts) >= 1
assert any("8081" in f.content for f in facts)
def test_extract_project_fact(self):
messages = [
{"role": "user", "content": "The project uses Gitea for source control."},
]
facts = extract_facts_from_messages(messages)
assert len(facts) >= 1
def test_skip_tool_results(self):
messages = [
{"role": "assistant", "content": "Running command...", "tool_calls": [{"id": "1"}]},
{"role": "tool", "content": "output here"},
]
facts = extract_facts_from_messages(messages)
assert len(facts) == 0
def test_skip_short_messages(self):
messages = [
{"role": "user", "content": "ok"},
]
facts = extract_facts_from_messages(messages)
assert len(facts) == 0
def test_deduplication(self):
messages = [
{"role": "user", "content": "I prefer Python."},
{"role": "user", "content": "I prefer Python."},
]
facts = extract_facts_from_messages(messages)
# Should deduplicate
python_facts = [f for f in facts if "Python" in f.content]
assert len(python_facts) == 1
class TestSaveFacts:
def test_save_with_callback(self):
saved = []
def mock_save(category, entity, content, trust):
saved.append({"category": category, "content": content})
facts = [ExtractedFact("user_pref", "user", "likes dark mode", 0.8, 0)]
count = save_facts_to_store(facts, fact_store_fn=mock_save)
assert count == 1
assert len(saved) == 1
class TestFormatSummary:
def test_empty(self):
assert "No facts" in format_facts_summary([])
def test_with_facts(self):
facts = [
ExtractedFact("user_pref", "user", "likes dark mode", 0.8, 0),
ExtractedFact("correction", "user", "port is 8081", 0.9, 1),
]
summary = format_facts_summary(facts)
assert "2 facts" in summary
assert "user_pref" in summary