diff --git a/tests/tools/test_syntax_preflight.py b/tests/tools/test_syntax_preflight.py new file mode 100644 index 000000000..488c36976 --- /dev/null +++ b/tests/tools/test_syntax_preflight.py @@ -0,0 +1,107 @@ +"""Tests for syntax preflight check in execute_code (issue #312).""" + +import ast +import json +import pytest + + +class TestSyntaxPreflight: + """Verify that execute_code catches syntax errors before sandbox execution.""" + + def test_valid_syntax_passes_parse(self): + """Valid Python should pass ast.parse.""" + code = "print('hello')\nx = 1 + 2\n" + ast.parse(code) # should not raise + + def test_syntax_error_indentation(self): + """IndentationError is a subclass of SyntaxError.""" + code = "def foo():\nbar()\n" + with pytest.raises(SyntaxError): + ast.parse(code) + + def test_syntax_error_missing_colon(self): + code = "if True\n pass\n" + with pytest.raises(SyntaxError): + ast.parse(code) + + def test_syntax_error_unmatched_paren(self): + code = "x = (1 + 2\n" + with pytest.raises(SyntaxError): + ast.parse(code) + + def test_syntax_error_invalid_token(self): + code = "x = 1 +*\n" + with pytest.raises(SyntaxError): + ast.parse(code) + + def test_syntax_error_details(self): + """SyntaxError should provide line, offset, msg.""" + code = "if True\n pass\n" + with pytest.raises(SyntaxError) as exc_info: + ast.parse(code) + e = exc_info.value + assert e.lineno is not None + assert e.msg is not None + + def test_empty_string_passes(self): + """Empty string is valid Python (empty module).""" + ast.parse("") + + def test_comments_only_passes(self): + ast.parse("# just a comment\n# another\n") + + def test_complex_valid_code(self): + code = ''' +import os +def foo(x): + if x > 0: + return x * 2 + return 0 + +result = [foo(i) for i in range(10)] +print(result) +''' + ast.parse(code) + + +class TestSyntaxPreflightResponse: + """Test the error response format from the preflight check.""" + + def _check_syntax(self, code): + """Mimic the preflight check logic from execute_code.""" + try: + ast.parse(code) + return None + except SyntaxError as e: + return json.dumps({ + "error": f"Python syntax error: {e.msg}", + "line": e.lineno, + "offset": e.offset, + "text": (e.text or "").strip()[:200], + }) + + def test_returns_json_error(self): + result = self._check_syntax("if True\n pass\n") + assert result is not None + data = json.loads(result) + assert "error" in data + assert "syntax error" in data["error"].lower() + + def test_includes_line_number(self): + result = self._check_syntax("x = 1\nif True\n pass\n") + data = json.loads(result) + assert data["line"] == 2 # error on line 2 + + def test_includes_offset(self): + result = self._check_syntax("x = (1 + 2\n") + data = json.loads(result) + assert data["offset"] is not None + + def test_includes_snippet(self): + result = self._check_syntax("if True\n") + data = json.loads(result) + assert "if True" in data["text"] + + def test_none_for_valid_code(self): + result = self._check_syntax("print('ok')") + assert result is None diff --git a/tools/code_execution_tool.py b/tools/code_execution_tool.py index 8dd6c759e..aa17e58a1 100644 --- a/tools/code_execution_tool.py +++ b/tools/code_execution_tool.py @@ -28,6 +28,7 @@ Platform: Linux / macOS only (Unix domain sockets for local). Disabled on Window Remote execution additionally requires Python 3 in the terminal backend. """ +import ast import base64 import json import logging @@ -893,6 +894,20 @@ def execute_code( if not code or not code.strip(): return json.dumps({"error": "No code provided."}) + # Poka-yoke (#312): Syntax check before execution. + # 83.2% of execute_code errors are Python exceptions; most are syntax + # errors the LLM generated. ast.parse() is sub-millisecond and catches + # them before we spin up a sandbox child process. + try: + ast.parse(code) + except SyntaxError as e: + return json.dumps({ + "error": f"Python syntax error: {e.msg}", + "line": e.lineno, + "offset": e.offset, + "text": (e.text or "").strip()[:200], + }) + # Dispatch: remote backends use file-based RPC, local uses UDS from tools.terminal_tool import _get_env_config env_type = _get_env_config()["env_type"]