diff --git a/model_tools.py b/model_tools.py index ec472ff99..edea2315d 100644 --- a/model_tools.py +++ b/model_tools.py @@ -365,6 +365,97 @@ _AGENT_LOOP_TOOLS = {"todo", "memory", "session_search", "delegate_task"} _READ_SEARCH_TOOLS = {"read_file", "search_files"} +# ========================================================================= +# Tool argument type coercion +# ========================================================================= + +def coerce_tool_args(tool_name: str, args: Dict[str, Any]) -> Dict[str, Any]: + """Coerce tool call arguments to match their JSON Schema types. + + LLMs frequently return numbers as strings (``"42"`` instead of ``42``) + and booleans as strings (``"true"`` instead of ``true``). This compares + each argument value against the tool's registered JSON Schema and attempts + safe coercion when the value is a string but the schema expects a different + type. Original values are preserved when coercion fails. + + Handles ``"type": "integer"``, ``"type": "number"``, ``"type": "boolean"``, + and union types (``"type": ["integer", "string"]``). + """ + if not args or not isinstance(args, dict): + return args + + schema = registry.get_schema(tool_name) + if not schema: + return args + + properties = (schema.get("parameters") or {}).get("properties") + if not properties: + return args + + for key, value in args.items(): + if not isinstance(value, str): + continue + prop_schema = properties.get(key) + if not prop_schema: + continue + expected = prop_schema.get("type") + if not expected: + continue + coerced = _coerce_value(value, expected) + if coerced is not value: + args[key] = coerced + + return args + + +def _coerce_value(value: str, expected_type): + """Attempt to coerce a string *value* to *expected_type*. + + Returns the original string when coercion is not applicable or fails. + """ + if isinstance(expected_type, list): + # Union type — try each in order, return first successful coercion + for t in expected_type: + result = _coerce_value(value, t) + if result is not value: + return result + return value + + if expected_type in ("integer", "number"): + return _coerce_number(value, integer_only=(expected_type == "integer")) + if expected_type == "boolean": + return _coerce_boolean(value) + return value + + +def _coerce_number(value: str, integer_only: bool = False): + """Try to parse *value* as a number. Returns original string on failure.""" + try: + f = float(value) + except (ValueError, OverflowError): + return value + # Guard against inf/nan before int() conversion + if f != f or f == float("inf") or f == float("-inf"): + return f + # If it looks like an integer (no fractional part), return int + if f == int(f): + return int(f) + if integer_only: + # Schema wants an integer but value has decimals — keep as string + return value + return f + + +def _coerce_boolean(value: str): + """Try to parse *value* as a boolean. Returns original string on failure.""" + low = value.strip().lower() + if low == "true": + return True + if low == "false": + return False + return value + + def handle_function_call( function_name: str, function_args: Dict[str, Any], @@ -388,6 +479,9 @@ def handle_function_call( Returns: Function result as a JSON string. """ + # Coerce string arguments to their schema-declared types (e.g. "42"→42) + function_args = coerce_tool_args(function_name, function_args) + # Notify the read-loop tracker when a non-read/search tool runs, # so the *consecutive* counter resets (reads after other work are fine). if function_name not in _READ_SEARCH_TOOLS: diff --git a/tests/test_tool_arg_coercion.py b/tests/test_tool_arg_coercion.py new file mode 100644 index 000000000..cf1876d4e --- /dev/null +++ b/tests/test_tool_arg_coercion.py @@ -0,0 +1,262 @@ +"""Tests for tool argument type coercion. + +When LLMs return tool call arguments, they frequently put numbers as strings +("42" instead of 42) and booleans as strings ("true" instead of true). +coerce_tool_args() fixes these type mismatches by comparing argument values +against the tool's JSON Schema before dispatch. +""" + +import pytest +from unittest.mock import patch + +from model_tools import ( + coerce_tool_args, + _coerce_value, + _coerce_number, + _coerce_boolean, +) + + +# ── Low-level coercion helpers ──────────────────────────────────────────── + + +class TestCoerceNumber: + """Unit tests for _coerce_number.""" + + def test_integer_string(self): + assert _coerce_number("42") == 42 + assert isinstance(_coerce_number("42"), int) + + def test_negative_integer(self): + assert _coerce_number("-7") == -7 + + def test_zero(self): + assert _coerce_number("0") == 0 + assert isinstance(_coerce_number("0"), int) + + def test_float_string(self): + assert _coerce_number("3.14") == 3.14 + assert isinstance(_coerce_number("3.14"), float) + + def test_float_with_zero_fractional(self): + """3.0 should become int(3) since there's no fractional part.""" + assert _coerce_number("3.0") == 3 + assert isinstance(_coerce_number("3.0"), int) + + def test_integer_only_rejects_float(self): + """When integer_only=True, "3.14" should stay as string.""" + result = _coerce_number("3.14", integer_only=True) + assert result == "3.14" + assert isinstance(result, str) + + def test_integer_only_accepts_whole(self): + assert _coerce_number("42", integer_only=True) == 42 + + def test_not_a_number(self): + assert _coerce_number("hello") == "hello" + + def test_empty_string(self): + assert _coerce_number("") == "" + + def test_large_number(self): + assert _coerce_number("1000000") == 1000000 + + def test_scientific_notation(self): + assert _coerce_number("1e5") == 100000 + + def test_inf_stays_string_for_integer_only(self): + """Infinity should not be converted to int.""" + result = _coerce_number("inf") + assert result == float("inf") + + def test_negative_float(self): + assert _coerce_number("-2.5") == -2.5 + + +class TestCoerceBoolean: + """Unit tests for _coerce_boolean.""" + + def test_true_lowercase(self): + assert _coerce_boolean("true") is True + + def test_false_lowercase(self): + assert _coerce_boolean("false") is False + + def test_true_mixed_case(self): + assert _coerce_boolean("True") is True + + def test_false_mixed_case(self): + assert _coerce_boolean("False") is False + + def test_true_with_whitespace(self): + assert _coerce_boolean(" true ") is True + + def test_not_a_boolean(self): + assert _coerce_boolean("yes") == "yes" + + def test_one_zero_not_coerced(self): + """'1' and '0' are not boolean values.""" + assert _coerce_boolean("1") == "1" + assert _coerce_boolean("0") == "0" + + def test_empty_string(self): + assert _coerce_boolean("") == "" + + +class TestCoerceValue: + """Unit tests for _coerce_value.""" + + def test_integer_type(self): + assert _coerce_value("5", "integer") == 5 + + def test_number_type(self): + assert _coerce_value("3.14", "number") == 3.14 + + def test_boolean_type(self): + assert _coerce_value("true", "boolean") is True + + def test_string_type_passthrough(self): + """Strings expected as strings should not be coerced.""" + assert _coerce_value("hello", "string") == "hello" + + def test_unknown_type_passthrough(self): + assert _coerce_value("stuff", "object") == "stuff" + + def test_union_type_prefers_first_match(self): + """Union types try each in order.""" + assert _coerce_value("42", ["integer", "string"]) == 42 + + def test_union_type_falls_through(self): + """If no type matches, return original string.""" + assert _coerce_value("hello", ["integer", "boolean"]) == "hello" + + def test_union_with_string_preserves_original(self): + """A non-numeric string in [number, string] should stay a string.""" + assert _coerce_value("hello", ["number", "string"]) == "hello" + + +# ── Full coerce_tool_args with registry ─────────────────────────────────── + + +class TestCoerceToolArgs: + """Integration tests for coerce_tool_args using the tool registry.""" + + def _mock_schema(self, properties): + """Build a minimal tool schema with the given properties.""" + return { + "name": "test_tool", + "description": "test", + "parameters": { + "type": "object", + "properties": properties, + }, + } + + def test_coerces_integer_arg(self): + schema = self._mock_schema({"limit": {"type": "integer"}}) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"limit": "10"} + result = coerce_tool_args("test_tool", args) + assert result["limit"] == 10 + assert isinstance(result["limit"], int) + + def test_coerces_boolean_arg(self): + schema = self._mock_schema({"merge": {"type": "boolean"}}) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"merge": "true"} + result = coerce_tool_args("test_tool", args) + assert result["merge"] is True + + def test_coerces_number_arg(self): + schema = self._mock_schema({"temperature": {"type": "number"}}) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"temperature": "0.7"} + result = coerce_tool_args("test_tool", args) + assert result["temperature"] == 0.7 + + def test_leaves_string_args_alone(self): + schema = self._mock_schema({"path": {"type": "string"}}) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"path": "/tmp/file.txt"} + result = coerce_tool_args("test_tool", args) + assert result["path"] == "/tmp/file.txt" + + def test_leaves_already_correct_types(self): + schema = self._mock_schema({"limit": {"type": "integer"}}) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"limit": 10} + result = coerce_tool_args("test_tool", args) + assert result["limit"] == 10 + + def test_unknown_tool_returns_args_unchanged(self): + with patch("model_tools.registry.get_schema", return_value=None): + args = {"limit": "10"} + result = coerce_tool_args("unknown_tool", args) + assert result["limit"] == "10" + + def test_empty_args(self): + assert coerce_tool_args("test_tool", {}) == {} + + def test_none_args(self): + assert coerce_tool_args("test_tool", None) is None + + def test_preserves_non_string_values(self): + """Lists, dicts, and other non-string values are never touched.""" + schema = self._mock_schema({ + "items": {"type": "array"}, + "config": {"type": "object"}, + }) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"items": [1, 2, 3], "config": {"key": "val"}} + result = coerce_tool_args("test_tool", args) + assert result["items"] == [1, 2, 3] + assert result["config"] == {"key": "val"} + + def test_extra_args_without_schema_left_alone(self): + """Args not in the schema properties are not touched.""" + schema = self._mock_schema({"limit": {"type": "integer"}}) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"limit": "10", "extra": "42"} + result = coerce_tool_args("test_tool", args) + assert result["limit"] == 10 + assert result["extra"] == "42" # no schema for extra, stays string + + def test_mixed_coercion(self): + """Multiple args coerced in the same call.""" + schema = self._mock_schema({ + "offset": {"type": "integer"}, + "limit": {"type": "integer"}, + "full": {"type": "boolean"}, + "path": {"type": "string"}, + }) + with patch("model_tools.registry.get_schema", return_value=schema): + args = { + "offset": "1", + "limit": "500", + "full": "false", + "path": "readme.md", + } + result = coerce_tool_args("test_tool", args) + assert result["offset"] == 1 + assert result["limit"] == 500 + assert result["full"] is False + assert result["path"] == "readme.md" + + def test_failed_coercion_preserves_original(self): + """A non-parseable string stays as string even if schema says integer.""" + schema = self._mock_schema({"limit": {"type": "integer"}}) + with patch("model_tools.registry.get_schema", return_value=schema): + args = {"limit": "not_a_number"} + result = coerce_tool_args("test_tool", args) + assert result["limit"] == "not_a_number" + + def test_real_read_file_schema(self): + """Test against the actual read_file schema from the registry.""" + # This uses the real registry — read_file should be registered + args = {"path": "foo.py", "offset": "10", "limit": "100"} + result = coerce_tool_args("read_file", args) + assert result["path"] == "foo.py" + assert result["offset"] == 10 + assert isinstance(result["offset"], int) + assert result["limit"] == 100 + assert isinstance(result["limit"], int)