feat: coerce tool call arguments to match JSON Schema types (#5265)

LLMs frequently return numbers as strings ("42" instead of 42) and
booleans as strings ("true" instead of true). This causes silent
failures with MCP tools and any tool with strictly-typed parameters.

Added coerce_tool_args() in model_tools.py that runs before every tool
dispatch. For each argument, it checks the tool registry schema and
attempts safe coercion:
  - "42" → 42 when schema says "type": "integer"
  - "3.14" → 3.14 when schema says "type": "number"
  - "true"/"false" → True/False when schema says "type": "boolean"
  - Union types tried in order
  - Original values preserved when coercion fails or is not applicable

Inspired by Block/goose tool argument coercion system.
This commit is contained in:
Teknium
2026-04-05 10:57:34 -07:00
committed by GitHub
parent e899d6a05d
commit 35d280d0bd
2 changed files with 356 additions and 0 deletions

View File

@@ -365,6 +365,97 @@ _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"}
_READ_SEARCH_TOOLS = {"read_file", "search_files"}
# =========================================================================
# Tool argument type coercion
# =========================================================================
def coerce_tool_args(tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]:
"""Coerce tool call arguments to match their JSON Schema types.
LLMs frequently return numbers as strings (``"42"`` instead of ``42``)
and booleans as strings (``"true"`` instead of ``true``). This compares
each argument value against the tool's registered JSON Schema and attempts
safe coercion when the value is a string but the schema expects a different
type. Original values are preserved when coercion fails.
Handles ``"type": "integer"``, ``"type": "number"``, ``"type": "boolean"``,
and union types (``"type": ["integer", "string"]``).
"""
if not args or not isinstance(args, dict):
return args
schema = registry.get_schema(tool_name)
if not schema:
return args
properties = (schema.get("parameters") or {}).get("properties")
if not properties:
return args
for key, value in args.items():
if not isinstance(value, str):
continue
prop_schema = properties.get(key)
if not prop_schema:
continue
expected = prop_schema.get("type")
if not expected:
continue
coerced = _coerce_value(value, expected)
if coerced is not value:
args[key] = coerced
return args
def _coerce_value(value: str, expected_type):
"""Attempt to coerce a string *value* to *expected_type*.
Returns the original string when coercion is not applicable or fails.
"""
if isinstance(expected_type, list):
# Union type — try each in order, return first successful coercion
for t in expected_type:
result = _coerce_value(value, t)
if result is not value:
return result
return value
if expected_type in ("integer", "number"):
return _coerce_number(value, integer_only=(expected_type == "integer"))
if expected_type == "boolean":
return _coerce_boolean(value)
return value
def _coerce_number(value: str, integer_only: bool = False):
"""Try to parse *value* as a number. Returns original string on failure."""
try:
f = float(value)
except (ValueError, OverflowError):
return value
# Guard against inf/nan before int() conversion
if f != f or f == float("inf") or f == float("-inf"):
return f
# If it looks like an integer (no fractional part), return int
if f == int(f):
return int(f)
if integer_only:
# Schema wants an integer but value has decimals — keep as string
return value
return f
def _coerce_boolean(value: str):
"""Try to parse *value* as a boolean. Returns original string on failure."""
low = value.strip().lower()
if low == "true":
return True
if low == "false":
return False
return value
def handle_function_call(
function_name: str,
function_args: Dict[str, Any],
@@ -388,6 +479,9 @@ def handle_function_call(
Returns:
Function result as a JSON string.
"""
# Coerce string arguments to their schema-declared types (e.g. "42"→42)
function_args = coerce_tool_args(function_name, function_args)
# Notify the read-loop tracker when a non-read/search tool runs,
# so the *consecutive* counter resets (reads after other work are fine).
if function_name not in _READ_SEARCH_TOOLS:

View File

@@ -0,0 +1,262 @@
"""Tests for tool argument type coercion.
When LLMs return tool call arguments, they frequently put numbers as strings
("42" instead of 42) and booleans as strings ("true" instead of true).
coerce_tool_args() fixes these type mismatches by comparing argument values
against the tool's JSON Schema before dispatch.
"""
import pytest
from unittest.mock import patch
from model_tools import (
coerce_tool_args,
_coerce_value,
_coerce_number,
_coerce_boolean,
)
# ── Low-level coercion helpers ────────────────────────────────────────────
class TestCoerceNumber:
"""Unit tests for _coerce_number."""
def test_integer_string(self):
assert _coerce_number("42") == 42
assert isinstance(_coerce_number("42"), int)
def test_negative_integer(self):
assert _coerce_number("-7") == -7
def test_zero(self):
assert _coerce_number("0") == 0
assert isinstance(_coerce_number("0"), int)
def test_float_string(self):
assert _coerce_number("3.14") == 3.14
assert isinstance(_coerce_number("3.14"), float)
def test_float_with_zero_fractional(self):
"""3.0 should become int(3) since there's no fractional part."""
assert _coerce_number("3.0") == 3
assert isinstance(_coerce_number("3.0"), int)
def test_integer_only_rejects_float(self):
"""When integer_only=True, "3.14" should stay as string."""
result = _coerce_number("3.14", integer_only=True)
assert result == "3.14"
assert isinstance(result, str)
def test_integer_only_accepts_whole(self):
assert _coerce_number("42", integer_only=True) == 42
def test_not_a_number(self):
assert _coerce_number("hello") == "hello"
def test_empty_string(self):
assert _coerce_number("") == ""
def test_large_number(self):
assert _coerce_number("1000000") == 1000000
def test_scientific_notation(self):
assert _coerce_number("1e5") == 100000
def test_inf_stays_string_for_integer_only(self):
"""Infinity should not be converted to int."""
result = _coerce_number("inf")
assert result == float("inf")
def test_negative_float(self):
assert _coerce_number("-2.5") == -2.5
class TestCoerceBoolean:
"""Unit tests for _coerce_boolean."""
def test_true_lowercase(self):
assert _coerce_boolean("true") is True
def test_false_lowercase(self):
assert _coerce_boolean("false") is False
def test_true_mixed_case(self):
assert _coerce_boolean("True") is True
def test_false_mixed_case(self):
assert _coerce_boolean("False") is False
def test_true_with_whitespace(self):
assert _coerce_boolean(" true ") is True
def test_not_a_boolean(self):
assert _coerce_boolean("yes") == "yes"
def test_one_zero_not_coerced(self):
"""'1' and '0' are not boolean values."""
assert _coerce_boolean("1") == "1"
assert _coerce_boolean("0") == "0"
def test_empty_string(self):
assert _coerce_boolean("") == ""
class TestCoerceValue:
"""Unit tests for _coerce_value."""
def test_integer_type(self):
assert _coerce_value("5", "integer") == 5
def test_number_type(self):
assert _coerce_value("3.14", "number") == 3.14
def test_boolean_type(self):
assert _coerce_value("true", "boolean") is True
def test_string_type_passthrough(self):
"""Strings expected as strings should not be coerced."""
assert _coerce_value("hello", "string") == "hello"
def test_unknown_type_passthrough(self):
assert _coerce_value("stuff", "object") == "stuff"
def test_union_type_prefers_first_match(self):
"""Union types try each in order."""
assert _coerce_value("42", ["integer", "string"]) == 42
def test_union_type_falls_through(self):
"""If no type matches, return original string."""
assert _coerce_value("hello", ["integer", "boolean"]) == "hello"
def test_union_with_string_preserves_original(self):
"""A non-numeric string in [number, string] should stay a string."""
assert _coerce_value("hello", ["number", "string"]) == "hello"
# ── Full coerce_tool_args with registry ───────────────────────────────────
class TestCoerceToolArgs:
"""Integration tests for coerce_tool_args using the tool registry."""
def _mock_schema(self, properties):
"""Build a minimal tool schema with the given properties."""
return {
"name": "test_tool",
"description": "test",
"parameters": {
"type": "object",
"properties": properties,
},
}
def test_coerces_integer_arg(self):
schema = self._mock_schema({"limit": {"type": "integer"}})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"limit": "10"}
result = coerce_tool_args("test_tool", args)
assert result["limit"] == 10
assert isinstance(result["limit"], int)
def test_coerces_boolean_arg(self):
schema = self._mock_schema({"merge": {"type": "boolean"}})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"merge": "true"}
result = coerce_tool_args("test_tool", args)
assert result["merge"] is True
def test_coerces_number_arg(self):
schema = self._mock_schema({"temperature": {"type": "number"}})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"temperature": "0.7"}
result = coerce_tool_args("test_tool", args)
assert result["temperature"] == 0.7
def test_leaves_string_args_alone(self):
schema = self._mock_schema({"path": {"type": "string"}})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"path": "/tmp/file.txt"}
result = coerce_tool_args("test_tool", args)
assert result["path"] == "/tmp/file.txt"
def test_leaves_already_correct_types(self):
schema = self._mock_schema({"limit": {"type": "integer"}})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"limit": 10}
result = coerce_tool_args("test_tool", args)
assert result["limit"] == 10
def test_unknown_tool_returns_args_unchanged(self):
with patch("model_tools.registry.get_schema", return_value=None):
args = {"limit": "10"}
result = coerce_tool_args("unknown_tool", args)
assert result["limit"] == "10"
def test_empty_args(self):
assert coerce_tool_args("test_tool", {}) == {}
def test_none_args(self):
assert coerce_tool_args("test_tool", None) is None
def test_preserves_non_string_values(self):
"""Lists, dicts, and other non-string values are never touched."""
schema = self._mock_schema({
"items": {"type": "array"},
"config": {"type": "object"},
})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"items": [1, 2, 3], "config": {"key": "val"}}
result = coerce_tool_args("test_tool", args)
assert result["items"] == [1, 2, 3]
assert result["config"] == {"key": "val"}
def test_extra_args_without_schema_left_alone(self):
"""Args not in the schema properties are not touched."""
schema = self._mock_schema({"limit": {"type": "integer"}})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"limit": "10", "extra": "42"}
result = coerce_tool_args("test_tool", args)
assert result["limit"] == 10
assert result["extra"] == "42" # no schema for extra, stays string
def test_mixed_coercion(self):
"""Multiple args coerced in the same call."""
schema = self._mock_schema({
"offset": {"type": "integer"},
"limit": {"type": "integer"},
"full": {"type": "boolean"},
"path": {"type": "string"},
})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {
"offset": "1",
"limit": "500",
"full": "false",
"path": "readme.md",
}
result = coerce_tool_args("test_tool", args)
assert result["offset"] == 1
assert result["limit"] == 500
assert result["full"] is False
assert result["path"] == "readme.md"
def test_failed_coercion_preserves_original(self):
"""A non-parseable string stays as string even if schema says integer."""
schema = self._mock_schema({"limit": {"type": "integer"}})
with patch("model_tools.registry.get_schema", return_value=schema):
args = {"limit": "not_a_number"}
result = coerce_tool_args("test_tool", args)
assert result["limit"] == "not_a_number"
def test_real_read_file_schema(self):
"""Test against the actual read_file schema from the registry."""
# This uses the real registry — read_file should be registered
args = {"path": "foo.py", "offset": "10", "limit": "100"}
result = coerce_tool_args("read_file", args)
assert result["path"] == "foo.py"
assert result["offset"] == 10
assert isinstance(result["offset"], int)
assert result["limit"] == 100
assert isinstance(result["limit"], int)