Merge PR #701: fix: tool call repair — auto-lowercase, fuzzy match, helpful error on unknown tool

Authored by teyrebaz33. Adds _repair_tool_call() method: tries lowercase,
normalize (hyphens/spaces → underscores), then fuzzy match (difflib, 0.7
cutoff). Replaces hard abort after 3 retries with graceful error message
sent back to model for self-correction. Fixed bug where valid tool calls
in a mixed batch would get no results (now all get results).
Fixes #520.
This commit is contained in:
teknium1
2026-03-10 06:54:17 -07:00

View File

@@ -1442,6 +1442,34 @@ class AIAgent:
return "\n\n".join(prompt_parts)
def _repair_tool_call(self, tool_name: str) -> str | None:
"""Attempt to repair a mismatched tool name before aborting.
1. Try lowercase
2. Try normalized (lowercase + hyphens/spaces -> underscores)
3. Try fuzzy match (difflib, cutoff=0.7)
Returns the repaired name if found in valid_tool_names, else None.
"""
from difflib import get_close_matches
# 1. Lowercase
lowered = tool_name.lower()
if lowered in self.valid_tool_names:
return lowered
# 2. Normalize
normalized = lowered.replace("-", "_").replace(" ", "_")
if normalized in self.valid_tool_names:
return normalized
# 3. Fuzzy match
matches = get_close_matches(lowered, self.valid_tool_names, n=1, cutoff=0.7)
if matches:
return matches[0]
return None
def _invalidate_system_prompt(self):
"""
Invalidate the cached system prompt, forcing a rebuild on the next turn.
@@ -4067,39 +4095,37 @@ class AIAgent:
logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...")
# Validate tool call names - detect model hallucinations
# Repair mismatched tool names before validating
for tc in assistant_message.tool_calls:
if tc.function.name not in self.valid_tool_names:
repaired = self._repair_tool_call(tc.function.name)
if repaired:
print(f"{self.log_prefix}🔧 Auto-repaired tool name: '{tc.function.name}' -> '{repaired}'")
tc.function.name = repaired
invalid_tool_calls = [
tc.function.name for tc in assistant_message.tool_calls
tc.function.name for tc in assistant_message.tool_calls
if tc.function.name not in self.valid_tool_names
]
if invalid_tool_calls:
# Track retries for invalid tool calls
if not hasattr(self, '_invalid_tool_retries'):
self._invalid_tool_retries = 0
self._invalid_tool_retries += 1
invalid_preview = invalid_tool_calls[0][:80] + "..." if len(invalid_tool_calls[0]) > 80 else invalid_tool_calls[0]
print(f"{self.log_prefix}⚠️ Invalid tool call detected: '{invalid_preview}'")
print(f"{self.log_prefix} Valid tools: {sorted(self.valid_tool_names)}")
if self._invalid_tool_retries < 3:
print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_tool_retries}/3)...")
# Don't add anything to messages, just retry the API call
continue
else:
print(f"{self.log_prefix}❌ Max retries (3) for invalid tool calls exceeded. Stopping as partial.")
# Return partial result - don't include the bad tool call in messages
self._invalid_tool_retries = 0
self._persist_session(messages, conversation_history)
return {
"final_response": None,
"messages": messages,
"api_calls": api_call_count,
"completed": False,
"partial": True,
"error": f"Model generated invalid tool call: {invalid_preview}"
}
# Return helpful error to model — model can self-correct next turn
available = ", ".join(sorted(self.valid_tool_names))
invalid_name = invalid_tool_calls[0]
invalid_preview = invalid_name[:80] + "..." if len(invalid_name) > 80 else invalid_name
print(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction")
assistant_msg = self._build_assistant_message(assistant_message, finish_reason)
messages.append(assistant_msg)
self._log_msg_to_db(assistant_msg)
for tc in assistant_message.tool_calls:
if tc.function.name not in self.valid_tool_names:
content = f"Tool '{tc.function.name}' does not exist. Available tools: {available}"
else:
content = f"Skipped: another tool call in this turn used an invalid name. Please retry this tool call."
messages.append({
"role": "tool",
"tool_call_id": tc.id,
"content": content,
})
continue
# Reset retry counter on successful tool call validation
if hasattr(self, '_invalid_tool_retries'):
self._invalid_tool_retries = 0