fix(mistral-parser): handle nested JSON in fallback extraction
This commit is contained in:
@@ -10,7 +10,6 @@ The [TOOL_CALLS] token is the bot_token used by Mistral models.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -42,9 +41,6 @@ class MistralToolCallParser(ToolCallParser):
|
||||
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
|
||||
BOT_TOKEN = "[TOOL_CALLS]"
|
||||
|
||||
# Fallback regex for pre-v11 format when JSON parsing fails
|
||||
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
|
||||
|
||||
def parse(self, text: str) -> ParseResult:
|
||||
if self.BOT_TOKEN not in text:
|
||||
return text, None
|
||||
@@ -71,6 +67,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
tool_name = raw[:brace_idx].strip()
|
||||
args_str = raw[brace_idx:]
|
||||
|
||||
# Validate and clean the JSON arguments
|
||||
try:
|
||||
parsed_args = json.loads(args_str)
|
||||
args_str = json.dumps(parsed_args, ensure_ascii=False)
|
||||
except json.JSONDecodeError:
|
||||
pass # Keep raw if parsing fails
|
||||
|
||||
tool_calls.append(
|
||||
ChatCompletionMessageToolCall(
|
||||
id=_generate_mistral_id(),
|
||||
@@ -100,13 +103,14 @@ class MistralToolCallParser(ToolCallParser):
|
||||
)
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# Fallback regex extraction
|
||||
match = self.TOOL_CALL_REGEX.findall(first_raw)
|
||||
if match:
|
||||
for raw_json in match:
|
||||
try:
|
||||
tc = json.loads(raw_json)
|
||||
args = tc.get("arguments", {})
|
||||
# Fallback: extract JSON objects using raw_decode
|
||||
decoder = json.JSONDecoder()
|
||||
idx = 0
|
||||
while idx < len(first_raw):
|
||||
try:
|
||||
obj, end_idx = decoder.raw_decode(first_raw, idx)
|
||||
if isinstance(obj, dict) and "name" in obj:
|
||||
args = obj.get("arguments", {})
|
||||
if isinstance(args, dict):
|
||||
args = json.dumps(args, ensure_ascii=False)
|
||||
tool_calls.append(
|
||||
@@ -114,12 +118,13 @@ class MistralToolCallParser(ToolCallParser):
|
||||
id=_generate_mistral_id(),
|
||||
type="function",
|
||||
function=Function(
|
||||
name=tc["name"], arguments=args
|
||||
name=obj["name"], arguments=args
|
||||
),
|
||||
)
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
continue
|
||||
idx = end_idx
|
||||
except json.JSONDecodeError:
|
||||
idx += 1
|
||||
|
||||
if not tool_calls:
|
||||
return text, None
|
||||
|
||||
@@ -209,3 +209,66 @@ class TestDeepSeekV3Parser:
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
|
||||
|
||||
# ─── Mistral parser tests ───────────────────────────────────────────────
|
||||
|
||||
class TestMistralParser:
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
return get_parser("mistral")
|
||||
|
||||
def test_no_tool_call(self, parser):
|
||||
text = "Hello, how can I help you?"
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == text
|
||||
assert tool_calls is None
|
||||
|
||||
def test_pre_v11_single_tool_call(self, parser):
|
||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"key": "val"}}]'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["key"] == "val"
|
||||
|
||||
def test_pre_v11_nested_json(self, parser):
|
||||
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"nested": {"deep": true}}}]'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["nested"]["deep"] is True
|
||||
|
||||
def test_v11_single_tool_call(self, parser):
|
||||
text = '[TOOL_CALLS]get_weather{"city": "London"}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "get_weather"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["city"] == "London"
|
||||
|
||||
def test_v11_multiple_tool_calls(self, parser):
|
||||
text = '[TOOL_CALLS]func1{"a": 1}[TOOL_CALLS]func2{"b": 2}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 2
|
||||
names = [tc.function.name for tc in tool_calls]
|
||||
assert "func1" in names
|
||||
assert "func2" in names
|
||||
|
||||
def test_preceding_text_preserved(self, parser):
|
||||
text = 'Hello[TOOL_CALLS]func{"a": 1}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == "Hello"
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "func"
|
||||
|
||||
def test_malformed_json_fallback(self, parser):
|
||||
text = "[TOOL_CALLS] not valid json"
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is None
|
||||
|
||||
Reference in New Issue
Block a user