fix(mistral-parser): handle nested JSON in fallback extraction

This commit is contained in:
Himess
2026-03-17 15:47:33 +03:00
committed by Teknium
parent 8304a7716d
commit 5663980015
2 changed files with 82 additions and 14 deletions

View File

@@ -10,7 +10,6 @@ The [TOOL_CALLS] token is the bot_token used by Mistral models.
"""
import json
import re
import uuid
from typing import List, Optional
@@ -42,9 +41,6 @@ class MistralToolCallParser(ToolCallParser):
# The [TOOL_CALLS] token -- may appear as different strings depending on tokenizer
BOT_TOKEN = "[TOOL_CALLS]"
# Fallback regex for pre-v11 format when JSON parsing fails
TOOL_CALL_REGEX = re.compile(r"\[?\s*(\{.*?\})\s*\]?", re.DOTALL)
def parse(self, text: str) -> ParseResult:
if self.BOT_TOKEN not in text:
return text, None
@@ -71,6 +67,13 @@ class MistralToolCallParser(ToolCallParser):
tool_name = raw[:brace_idx].strip()
args_str = raw[brace_idx:]
# Validate and clean the JSON arguments
try:
parsed_args = json.loads(args_str)
args_str = json.dumps(parsed_args, ensure_ascii=False)
except json.JSONDecodeError:
pass # Keep raw if parsing fails
tool_calls.append(
ChatCompletionMessageToolCall(
id=_generate_mistral_id(),
@@ -100,13 +103,14 @@ class MistralToolCallParser(ToolCallParser):
)
)
except json.JSONDecodeError:
# Fallback regex extraction
match = self.TOOL_CALL_REGEX.findall(first_raw)
if match:
for raw_json in match:
try:
tc = json.loads(raw_json)
args = tc.get("arguments", {})
# Fallback: extract JSON objects using raw_decode
decoder = json.JSONDecoder()
idx = 0
while idx < len(first_raw):
try:
obj, end_idx = decoder.raw_decode(first_raw, idx)
if isinstance(obj, dict) and "name" in obj:
args = obj.get("arguments", {})
if isinstance(args, dict):
args = json.dumps(args, ensure_ascii=False)
tool_calls.append(
@@ -114,12 +118,13 @@ class MistralToolCallParser(ToolCallParser):
id=_generate_mistral_id(),
type="function",
function=Function(
name=tc["name"], arguments=args
name=obj["name"], arguments=args
),
)
)
except (json.JSONDecodeError, KeyError):
continue
idx = end_idx
except json.JSONDecodeError:
idx += 1
if not tool_calls:
return text, None

View File

@@ -209,3 +209,66 @@ class TestDeepSeekV3Parser:
content, tool_calls = parser.parse(text)
assert tool_calls is not None
assert len(tool_calls) == 1
# ─── Mistral parser tests ───────────────────────────────────────────────
class TestMistralParser:
@pytest.fixture
def parser(self):
return get_parser("mistral")
def test_no_tool_call(self, parser):
text = "Hello, how can I help you?"
content, tool_calls = parser.parse(text)
assert content == text
assert tool_calls is None
def test_pre_v11_single_tool_call(self, parser):
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"key": "val"}}]'
content, tool_calls = parser.parse(text)
assert tool_calls is not None
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "func"
args = json.loads(tool_calls[0].function.arguments)
assert args["key"] == "val"
def test_pre_v11_nested_json(self, parser):
text = '[TOOL_CALLS] [{"name": "func", "arguments": {"nested": {"deep": true}}}]'
content, tool_calls = parser.parse(text)
assert tool_calls is not None
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "func"
args = json.loads(tool_calls[0].function.arguments)
assert args["nested"]["deep"] is True
def test_v11_single_tool_call(self, parser):
text = '[TOOL_CALLS]get_weather{"city": "London"}'
content, tool_calls = parser.parse(text)
assert tool_calls is not None
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "get_weather"
args = json.loads(tool_calls[0].function.arguments)
assert args["city"] == "London"
def test_v11_multiple_tool_calls(self, parser):
text = '[TOOL_CALLS]func1{"a": 1}[TOOL_CALLS]func2{"b": 2}'
content, tool_calls = parser.parse(text)
assert tool_calls is not None
assert len(tool_calls) == 2
names = [tc.function.name for tc in tool_calls]
assert "func1" in names
assert "func2" in names
def test_preceding_text_preserved(self, parser):
text = 'Hello[TOOL_CALLS]func{"a": 1}'
content, tool_calls = parser.parse(text)
assert content == "Hello"
assert tool_calls is not None
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "func"
def test_malformed_json_fallback(self, parser):
text = "[TOOL_CALLS] not valid json"
content, tool_calls = parser.parse(text)
assert tool_calls is None