1
0

fix: replace eval() with AST-walking safe evaluator in calculator

Fixes #52

- Replace eval() in calculator() with _safe_eval() that walks the AST
  and only permits: numeric constants, arithmetic ops (+,-,*,/,//,%,**),
  unary +/-, math module access, and whitelisted builtins (abs, round,
  min, max)
- Reject all other syntax: imports, attribute access on non-math objects,
  lambdas, comprehensions, string literals, etc.
- Add 39 tests covering arithmetic, precedence, math functions,
  allowed builtins, error handling, and 14 injection prevention cases
This commit is contained in:
2026-03-14 15:51:35 -04:00
parent 122d07471e
commit 70d5dc5ce1
2 changed files with 226 additions and 3 deletions

View File

@@ -13,6 +13,7 @@ Tools are assigned to agents based on their specialties.
from __future__ import annotations
import ast
import logging
import math
from collections.abc import Callable
@@ -115,6 +116,59 @@ def get_tool_stats(agent_id: str | None = None) -> dict:
return all_stats
def _safe_eval(node, allowed_names: dict):
"""Walk an AST and evaluate only safe numeric operations."""
if isinstance(node, ast.Expression):
return _safe_eval(node.body, allowed_names)
if isinstance(node, ast.Constant):
if isinstance(node.value, (int, float, complex)):
return node.value
raise ValueError(f"Unsupported constant: {node.value!r}")
if isinstance(node, ast.UnaryOp):
operand = _safe_eval(node.operand, allowed_names)
if isinstance(node.op, ast.UAdd):
return +operand
if isinstance(node.op, ast.USub):
return -operand
raise ValueError(f"Unsupported unary op: {type(node.op).__name__}")
if isinstance(node, ast.BinOp):
left = _safe_eval(node.left, allowed_names)
right = _safe_eval(node.right, allowed_names)
ops = {
ast.Add: lambda a, b: a + b,
ast.Sub: lambda a, b: a - b,
ast.Mult: lambda a, b: a * b,
ast.Div: lambda a, b: a / b,
ast.FloorDiv: lambda a, b: a // b,
ast.Mod: lambda a, b: a % b,
ast.Pow: lambda a, b: a**b,
}
op_fn = ops.get(type(node.op))
if op_fn is None:
raise ValueError(f"Unsupported binary op: {type(node.op).__name__}")
return op_fn(left, right)
if isinstance(node, ast.Name):
if node.id in allowed_names:
return allowed_names[node.id]
raise ValueError(f"Unknown name: {node.id!r}")
if isinstance(node, ast.Attribute):
value = _safe_eval(node.value, allowed_names)
# Only allow attribute access on the math module
if value is math:
attr = getattr(math, node.attr, None)
if attr is not None:
return attr
raise ValueError(f"Attribute access not allowed: .{node.attr}")
if isinstance(node, ast.Call):
func = _safe_eval(node.func, allowed_names)
if not callable(func):
raise ValueError(f"Not callable: {func!r}")
args = [_safe_eval(a, allowed_names) for a in node.args]
kwargs = {kw.arg: _safe_eval(kw.value, allowed_names) for kw in node.keywords}
return func(*args, **kwargs)
raise ValueError(f"Unsupported syntax: {type(node).__name__}")
def calculator(expression: str) -> str:
"""Evaluate a mathematical expression and return the exact result.
@@ -128,15 +182,15 @@ def calculator(expression: str) -> str:
Returns:
The exact result as a string.
"""
# Only expose math functions — no builtins, no file/os access
allowed_names = {k: getattr(math, k) for k in dir(math) if not k.startswith("_")}
allowed_names["math"] = math # Support math.sqrt(), math.pi, etc.
allowed_names["math"] = math
allowed_names["abs"] = abs
allowed_names["round"] = round
allowed_names["min"] = min
allowed_names["max"] = max
try:
result = eval(expression, {"__builtins__": {}}, allowed_names) # noqa: S307
tree = ast.parse(expression, mode="eval")
result = _safe_eval(tree, allowed_names)
return str(result)
except Exception as e:
return f"Error evaluating '{expression}': {e}"