diff --git a/run_agent.py b/run_agent.py index f138bdcc5..bde681eb4 100644 --- a/run_agent.py +++ b/run_agent.py @@ -1442,6 +1442,34 @@ class AIAgent: return "\n\n".join(prompt_parts) + def _repair_tool_call(self, tool_name: str) -> str | None: + """Attempt to repair a mismatched tool name before aborting. + + 1. Try lowercase + 2. Try normalized (lowercase + hyphens/spaces -> underscores) + 3. Try fuzzy match (difflib, cutoff=0.7) + + Returns the repaired name if found in valid_tool_names, else None. + """ + from difflib import get_close_matches + + # 1. Lowercase + lowered = tool_name.lower() + if lowered in self.valid_tool_names: + return lowered + + # 2. Normalize + normalized = lowered.replace("-", "_").replace(" ", "_") + if normalized in self.valid_tool_names: + return normalized + + # 3. Fuzzy match + matches = get_close_matches(lowered, self.valid_tool_names, n=1, cutoff=0.7) + if matches: + return matches[0] + + return None + def _invalidate_system_prompt(self): """ Invalidate the cached system prompt, forcing a rebuild on the next turn. @@ -4067,39 +4095,37 @@ class AIAgent: logging.debug(f"Tool call: {tc.function.name} with args: {tc.function.arguments[:200]}...") # Validate tool call names - detect model hallucinations + # Repair mismatched tool names before validating + for tc in assistant_message.tool_calls: + if tc.function.name not in self.valid_tool_names: + repaired = self._repair_tool_call(tc.function.name) + if repaired: + print(f"{self.log_prefix}🔧 Auto-repaired tool name: '{tc.function.name}' -> '{repaired}'") + tc.function.name = repaired invalid_tool_calls = [ - tc.function.name for tc in assistant_message.tool_calls + tc.function.name for tc in assistant_message.tool_calls if tc.function.name not in self.valid_tool_names ] - if invalid_tool_calls: - # Track retries for invalid tool calls - if not hasattr(self, '_invalid_tool_retries'): - self._invalid_tool_retries = 0 - self._invalid_tool_retries += 1 - - invalid_preview = invalid_tool_calls[0][:80] + "..." if len(invalid_tool_calls[0]) > 80 else invalid_tool_calls[0] - print(f"{self.log_prefix}⚠️ Invalid tool call detected: '{invalid_preview}'") - print(f"{self.log_prefix} Valid tools: {sorted(self.valid_tool_names)}") - - if self._invalid_tool_retries < 3: - print(f"{self.log_prefix}🔄 Retrying API call ({self._invalid_tool_retries}/3)...") - # Don't add anything to messages, just retry the API call - continue - else: - print(f"{self.log_prefix}❌ Max retries (3) for invalid tool calls exceeded. Stopping as partial.") - # Return partial result - don't include the bad tool call in messages - self._invalid_tool_retries = 0 - self._persist_session(messages, conversation_history) - return { - "final_response": None, - "messages": messages, - "api_calls": api_call_count, - "completed": False, - "partial": True, - "error": f"Model generated invalid tool call: {invalid_preview}" - } - + # Return helpful error to model — model can self-correct next turn + available = ", ".join(sorted(self.valid_tool_names)) + invalid_name = invalid_tool_calls[0] + invalid_preview = invalid_name[:80] + "..." if len(invalid_name) > 80 else invalid_name + print(f"{self.log_prefix}⚠️ Unknown tool '{invalid_preview}' — sending error to model for self-correction") + assistant_msg = self._build_assistant_message(assistant_message, finish_reason) + messages.append(assistant_msg) + self._log_msg_to_db(assistant_msg) + for tc in assistant_message.tool_calls: + if tc.function.name not in self.valid_tool_names: + content = f"Tool '{tc.function.name}' does not exist. Available tools: {available}" + else: + content = f"Skipped: another tool call in this turn used an invalid name. Please retry this tool call." + messages.append({ + "role": "tool", + "tool_call_id": tc.id, + "content": content, + }) + continue # Reset retry counter on successful tool call validation if hasattr(self, '_invalid_tool_retries'): self._invalid_tool_retries = 0