diff --git a/environments/tool_call_parsers/mistral_parser.py b/environments/tool_call_parsers/mistral_parser.py index 5526bdd01..50e98a6f8 100644 --- a/environments/tool_call_parsers/mistral_parser.py +++ b/environments/tool_call_parsers/mistral_parser.py @@ -10,7 +10,6 @@ The [TOOL_CALLS] token is the bot_token used by Mistral models. """ import json -import re import uuid from typing import List, Optional @@ -42,9 +41,6 @@ class MistralToolCallParser(ToolCallParser): # The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer BOT_TOKEN = "[TOOL_CALLS]" - # Fallback regex for pre-v11 format when JSON parsing fails - TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL) - def parse(self, text: str) -> ParseResult: if self.BOT_TOKEN not in text: return text, None @@ -71,6 +67,13 @@ class MistralToolCallParser(ToolCallParser): tool_name = raw[:brace_idx].strip() args_str = raw[brace_idx:] + # Validate and clean the JSON arguments + try: + parsed_args = json.loads(args_str) + args_str = json.dumps(parsed_args, ensure_ascii=False) + except json.JSONDecodeError: + pass # Keep raw if parsing fails + tool_calls.append( ChatCompletionMessageToolCall( id=_generate_mistral_id(), @@ -100,13 +103,14 @@ class MistralToolCallParser(ToolCallParser): ) ) except json.JSONDecodeError: - # Fallback regex extraction - match = self.TOOL_CALL_REGEX.findall(first_raw) - if match: - for raw_json in match: - try: - tc = json.loads(raw_json) - args = tc.get("arguments", {}) + # Fallback: extract JSON objects using raw_decode + decoder = json.JSONDecoder() + idx = 0 + while idx < len(first_raw): + try: + obj, end_idx = decoder.raw_decode(first_raw, idx) + if isinstance(obj, dict) and "name" in obj: + args = obj.get("arguments", {}) if isinstance(args, dict): args = json.dumps(args, ensure_ascii=False) tool_calls.append( @@ -114,12 +118,13 @@ class MistralToolCallParser(ToolCallParser): id=_generate_mistral_id(), type="function", function=Function( - name=tc["name"], arguments=args + name=obj["name"], arguments=args ), ) ) - except (json.JSONDecodeError, KeyError): - continue + idx = end_idx + except json.JSONDecodeError: + idx += 1 if not tool_calls: return text, None diff --git a/tests/test_tool_call_parsers.py b/tests/test_tool_call_parsers.py index 937463422..bdea75698 100644 --- a/tests/test_tool_call_parsers.py +++ b/tests/test_tool_call_parsers.py @@ -209,3 +209,66 @@ class TestDeepSeekV3Parser: content, tool_calls = parser.parse(text) assert tool_calls is not None assert len(tool_calls) == 1 + + +# ─── Mistral parser tests ─────────────────────────────────────────────── + +class TestMistralParser: + @pytest.fixture + def parser(self): + return get_parser("mistral") + + def test_no_tool_call(self, parser): + text = "Hello, how can I help you?" + content, tool_calls = parser.parse(text) + assert content == text + assert tool_calls is None + + def test_pre_v11_single_tool_call(self, parser): + text = '[TOOL_CALLS] [{"name": "func", "arguments": {"key": "val"}}]' + content, tool_calls = parser.parse(text) + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "func" + args = json.loads(tool_calls[0].function.arguments) + assert args["key"] == "val" + + def test_pre_v11_nested_json(self, parser): + text = '[TOOL_CALLS] [{"name": "func", "arguments": {"nested": {"deep": true}}}]' + content, tool_calls = parser.parse(text) + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "func" + args = json.loads(tool_calls[0].function.arguments) + assert args["nested"]["deep"] is True + + def test_v11_single_tool_call(self, parser): + text = '[TOOL_CALLS]get_weather{"city": "London"}' + content, tool_calls = parser.parse(text) + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "get_weather" + args = json.loads(tool_calls[0].function.arguments) + assert args["city"] == "London" + + def test_v11_multiple_tool_calls(self, parser): + text = '[TOOL_CALLS]func1{"a": 1}[TOOL_CALLS]func2{"b": 2}' + content, tool_calls = parser.parse(text) + assert tool_calls is not None + assert len(tool_calls) == 2 + names = [tc.function.name for tc in tool_calls] + assert "func1" in names + assert "func2" in names + + def test_preceding_text_preserved(self, parser): + text = 'Hello[TOOL_CALLS]func{"a": 1}' + content, tool_calls = parser.parse(text) + assert content == "Hello" + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "func" + + def test_malformed_json_fallback(self, parser): + text = "[TOOL_CALLS] not valid json" + content, tool_calls = parser.parse(text) + assert tool_calls is None