Fixes #52 - Replace eval() in calculator() with _safe_eval(), which walks the AST and permits only numeric constants, arithmetic operators (+, -, *, /, //, %, **), unary +/-, math module access, and whitelisted builtins (abs, round, min, max). - Reject all other syntax: imports, attribute access on non-math objects, lambdas, comprehensions, string literals, etc. - Add 39 tests covering arithmetic, precedence, math functions, allowed builtins, error handling, and 12 injection-prevention cases.
170 lines
4.8 KiB
Python
170 lines
4.8 KiB
Python
"""Tests for the safe calculator tool (issue #52)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
|
|
from timmy.tools import calculator
|
|
|
|
# ── Basic arithmetic ──────────────────────────────────────────────
|
|
|
|
|
|
class TestBasicArithmetic:
    """Single-operator expressions and unary signs.

    ``calculator`` hands its result back as a string, so every expectation
    below is a string literal or ``str(...)`` of the equivalent Python value.
    """

    def test_addition(self):
        got = calculator("2 + 3")
        assert got == "5"

    def test_subtraction(self):
        got = calculator("10 - 4")
        assert got == "6"

    def test_multiplication(self):
        # Let Python compute the reference product instead of hard-coding it.
        want = str(347 * 829)
        assert calculator("347 * 829") == want

    def test_division(self):
        # True division yields a float; compare against Python's own repr.
        want = str(10 / 3)
        assert calculator("10 / 3") == want

    def test_floor_division(self):
        got = calculator("10 // 3")
        assert got == "3"

    def test_modulo(self):
        got = calculator("10 % 3")
        assert got == "1"

    def test_exponent(self):
        got = calculator("2**10")
        assert got == "1024"

    def test_negative_number(self):
        # Unary minus on the left operand.
        got = calculator("-5 + 3")
        assert got == "-2"

    def test_unary_plus(self):
        got = calculator("+5")
        assert got == "5"
|
|
|
|
|
|
# ── Parentheses and precedence ────────────────────────────────────
|
|
|
|
|
|
class TestPrecedence:
    """Grouping with parentheses and standard operator precedence."""

    def test_nested_parens(self):
        expression = "(2 + 3) * (4 + 1)"
        assert calculator(expression) == "25"

    def test_deep_nesting(self):
        expression = "((1 + 2) * (3 + 4)) + 5"
        assert calculator(expression) == "26"

    def test_operator_precedence(self):
        # Multiplication binds tighter than addition: 2 + (3 * 4).
        expression = "2 + 3 * 4"
        assert calculator(expression) == "14"
|
|
|
|
|
|
# ── Math module functions ─────────────────────────────────────────
|
|
|
|
|
|
class TestMathFunctions:
    """Functions and constants reached through the ``math`` module.

    Where the exact float formatting matters, the expectation is produced by
    calling the same ``math`` function directly and stringifying it.
    """

    def test_sqrt(self):
        got = calculator("math.sqrt(144)")
        assert got == "12.0"

    def test_log(self):
        # Two-argument form: logarithm of 100 in base 10.
        want = str(math.log(100, 10))
        assert calculator("math.log(100, 10)") == want

    def test_sin(self):
        got = calculator("math.sin(0)")
        assert got == "0.0"

    def test_pi(self):
        want = str(math.pi)
        assert calculator("math.pi") == want

    def test_e(self):
        want = str(math.e)
        assert calculator("math.e") == want

    def test_ceil(self):
        got = calculator("math.ceil(4.3)")
        assert got == "5"

    def test_floor(self):
        got = calculator("math.floor(4.7)")
        assert got == "4"

    def test_bare_sqrt(self):
        # The function is also reachable without the ``math.`` prefix.
        got = calculator("sqrt(16)")
        assert got == "4.0"
|
|
|
|
|
|
# ── Allowed builtins ──────────────────────────────────────────────
|
|
|
|
|
|
class TestAllowedBuiltins:
    """The small set of builtins the calculator whitelists."""

    def test_abs(self):
        outcome = calculator("abs(-42)")
        assert outcome == "42"

    def test_round(self):
        # Two-argument round: keep two decimal places.
        outcome = calculator("round(3.14159, 2)")
        assert outcome == "3.14"

    def test_min(self):
        outcome = calculator("min(3, 1, 2)")
        assert outcome == "1"

    def test_max(self):
        outcome = calculator("max(3, 1, 2)")
        assert outcome == "3"
|
|
|
|
|
|
# ── Error handling ────────────────────────────────────────────────
|
|
|
|
|
|
class TestErrorHandling:
    """Malformed or failing input.

    Each case requires ``calculator`` to return a string containing
    ``"Error"``. An uncaught exception would fail the test.
    """

    def test_division_by_zero(self):
        assert "Error" in calculator("1 / 0")

    def test_syntax_error(self):
        # Dangling binary operator: the expression cannot be parsed.
        assert "Error" in calculator("2 +")

    def test_empty_expression(self):
        assert "Error" in calculator("")
|
|
|
|
|
|
# ── Injection attempts (the whole point of issue #52) ─────────────
|
|
|
|
|
|
class TestInjectionPrevention:
    """Expressions that could escape an ``eval()``-based calculator.

    Every case must be refused: ``calculator`` has to return a string
    containing ``"Error"`` instead of evaluating the payload.
    """

    def test_import_os(self):
        result = calculator("__import__('os').system('echo hacked')")
        assert "Error" in result
        # The refusal should come from the AST whitelist (unknown name or
        # unsupported node type), not from some later runtime failure.
        assert "Unknown name" in result or "Unsupported" in result

    def test_builtins_access(self):
        result = calculator("__builtins__")
        assert "Error" in result

    def test_dunder_class(self):
        # Classic sandbox escape: walk from a tuple up to object's subclasses.
        result = calculator("().__class__.__bases__[0].__subclasses__()")
        assert "Error" in result

    def test_exec(self):
        result = calculator("exec('import os')")
        assert "Error" in result

    def test_eval_nested(self):
        result = calculator("eval('1+1')")
        assert "Error" in result

    def test_open_file(self):
        result = calculator("open('/etc/passwd').read()")
        assert "Error" in result

    def test_string_literal_rejected(self):
        result = calculator("'hello'")
        assert "Error" in result

    def test_list_comprehension(self):
        result = calculator("[x for x in range(10)]")
        assert "Error" in result

    def test_lambda(self):
        result = calculator("(lambda: 1)()")
        assert "Error" in result

    def test_attribute_on_non_math(self):
        result = calculator("(1).__class__")
        assert "Error" in result

    def test_globals(self):
        result = calculator("globals()")
        assert "Error" in result

    def test_breakout_via_format(self):
        result = calculator("'{}'.format.__globals__")
        assert "Error" in result

    def test_dunder_on_math_module(self):
        # ``math`` is the only name on which attribute access is permitted,
        # which makes it the most likely bypass route. Its dunder attributes
        # lead back to the import machinery (__loader__, __spec__), the module
        # namespace (__dict__) and the type system (__class__).
        for expr in (
            "math.__dict__",
            "math.__loader__",
            "math.__spec__",
            "math.__class__",
        ):
            assert "Error" in calculator(expr), expr

    def test_attribute_on_math_result(self):
        # A value obtained from ``math`` is no longer the math module, so
        # attribute access on it must be rejected like any other object.
        result = calculator("math.pi.__class__")
        assert "Error" in result
|