# hermes-agent/tests/tools/test_code_execution_tool.py

"""Tests for tools/code_execution_tool.py - Security-critical module.
The module under test executes arbitrary code and therefore requires comprehensive security testing.
"""
import pytest
from unittest.mock import patch, MagicMock
from types import SimpleNamespace
# The module under test may be missing; remember whether the import worked so
# the tests in this file are skipped rather than failing on the import itself.
try:
    from tools.code_execution_tool import (
        execute_code,
        validate_code_safety,
        CodeExecutionError,
        ResourceLimitExceeded,
    )
except ImportError:
    HAS_MODULE = False
else:
    HAS_MODULE = True

# Module-level marks: pytest applies these to every test defined in this file.
pytestmark = [
    pytest.mark.skipif(not HAS_MODULE, reason="code_execution_tool module not found"),
    pytest.mark.security,  # Mark as security test
]
class TestValidateCodeSafety:
"""Tests for code safety validation."""
def test_blocks_dangerous_imports(self):
"""Should block imports of dangerous modules."""
dangerous_code = """
import os
os.system('rm -rf /')
"""
with pytest.raises(CodeExecutionError) as exc_info:
validate_code_safety(dangerous_code)
assert "dangerous import" in str(exc_info.value).lower()
def test_blocks_subprocess(self):
"""Should block subprocess module usage."""
code = """
import subprocess
subprocess.run(['ls', '-la'])
"""
with pytest.raises(CodeExecutionError):
validate_code_safety(code)
def test_blocks_compile_eval(self):
"""Should block compile() and eval() usage."""
code = "eval('__import__(\"os\").system(\"ls\")')"
with pytest.raises(CodeExecutionError):
validate_code_safety(code)
def test_blocks_file_operations(self):
"""Should block direct file operations."""
code = """
with open('/etc/passwd', 'r') as f:
data = f.read()
"""
with pytest.raises(CodeExecutionError):
validate_code_safety(code)
def test_allows_safe_code(self):
"""Should allow safe code execution."""
safe_code = """
def factorial(n):
if n <= 1:
return 1
return n * factorial(n - 1)
result = factorial(5)
"""
# Should not raise
validate_code_safety(safe_code)
def test_blocks_network_access(self):
"""Should block network-related imports."""
code = """
import socket
s = socket.socket()
"""
with pytest.raises(CodeExecutionError):
validate_code_safety(code)
class TestExecuteCode:
    """Tests for code execution with sandboxing.

    Every test calls ``execute_code`` and inspects either the returned
    mapping (``success``, ``variables``, ``error``, ``stdout``, ``stderr``)
    or the exception raised.
    """

    def test_executes_simple_code(self):
        """A trivial assignment runs and its variable is reported back."""
        outcome = execute_code("result = 2 + 2")
        assert outcome["success"] is True
        captured = outcome.get("variables", {})
        assert captured.get("result") == 4

    def test_handles_syntax_errors(self):
        """Malformed source yields a failed result mentioning 'syntax'."""
        outcome = execute_code("def broken(")
        assert outcome["success"] is False
        message = outcome.get("error", "")
        assert "syntax" in message.lower()

    def test_handles_runtime_errors(self):
        """A division by zero yields a failed result mentioning 'zero'."""
        outcome = execute_code("1 / 0")
        assert outcome["success"] is False
        message = outcome.get("error", "")
        assert "zero" in message.lower()

    def test_enforces_timeout(self):
        """A long sleep under a 1-second timeout raises ResourceLimitExceeded."""
        sleeper = (
            "\n"
            "import time\n"
            "time.sleep(100) # Long sleep\n"
        )
        with pytest.raises(ResourceLimitExceeded):
            execute_code(sleeper, timeout=1)

    def test_enforces_memory_limit(self):
        """A huge allocation under a 10 MB cap raises ResourceLimitExceeded."""
        allocator = (
            "\n"
            "# Try to allocate large amount of memory\n"
            "huge_list = [0] * (100 * 1024 * 1024) # 100M integers\n"
        )
        with pytest.raises(ResourceLimitExceeded):
            execute_code(allocator, memory_limit_mb=10)

    def test_restricts_available_modules(self):
        """Importing a module named in ``allowed_modules`` succeeds."""
        uses_math = (
            "\n"
            "import math\n"
            "result = math.sqrt(16)\n"
        )
        outcome = execute_code(uses_math, allowed_modules=["math"])
        assert outcome["success"] is True

    def test_captures_stdout(self):
        """Both printed lines appear in the captured stdout."""
        printer = (
            "\n"
            'print("Hello, World!")\n'
            'print("Second line")\n'
        )
        outcome = execute_code(printer)
        assert outcome["success"] is True
        assert "Hello, World!" in outcome.get("stdout", "")
        assert "Second line" in outcome.get("stdout", "")

    def test_captures_stderr(self):
        """Text written to sys.stderr appears in the captured stderr."""
        complainer = (
            "\n"
            "import sys\n"
            'print("Error message", file=sys.stderr)\n'
        )
        outcome = execute_code(complainer)
        assert "Error message" in outcome.get("stderr", "")
class TestResourceLimits:
"""Tests for resource limit enforcement."""
def test_cpu_time_limit(self):
"""Should limit CPU time usage."""
code = """
# CPU-intensive calculation
for i in range(10000000):
pass
"""
with pytest.raises(ResourceLimitExceeded):
execute_code(code, cpu_time_limit=0.1)
def test_output_size_limit(self):
"""Should limit output size."""
code = """
# Generate large output
print("x" * (10 * 1024 * 1024)) # 10MB of output
"""
with pytest.raises(ResourceLimitExceeded):
execute_code(code, max_output_size=1024)
class TestSecurityScenarios:
"""Security-focused test scenarios."""
def test_prevents_shell_injection(self):
"""Should prevent shell command injection."""
malicious_code = """
__import__('os').system('cat /etc/passwd')
"""
with pytest.raises(CodeExecutionError):
validate_code_safety(malicious_code)
def test_prevents_import_builtins_abuse(self):
"""Should prevent __builtins__ abuse."""
code = """
__builtins__['__import__']('os').system('ls')
"""
with pytest.raises(CodeExecutionError):
validate_code_safety(code)
def test_isolates_globals(self):
"""Should isolate global namespace between executions."""
code1 = "x = 42"
execute_code(code1)
code2 = "result = x + 1" # Should not have access to x
result = execute_code(code2)
assert result["success"] is False # NameError expected
def test_prevents_infinite_recursion(self):
"""Should prevent/recover from infinite recursion."""
code = """
def recurse():
return recurse()
recurse()
"""
with pytest.raises(ResourceLimitExceeded):
execute_code(code, max_recursion_depth=100)