Compare commits
8 Commits
security/f
...
security/f
| Author | SHA1 | Date | |
|---|---|---|---|
| ed32487cbe | |||
| 37c5e672b5 | |||
| cfcffd38ab | |||
| 0b49540db3 | |||
| ffa8405cfb | |||
| cc1b9e8054 | |||
| e2e88b271d | |||
| 0e01f3321d |
@@ -292,7 +292,29 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||||||
extra = config.extra or {}
|
extra = config.extra or {}
|
||||||
self._host: str = extra.get("host", os.getenv("API_SERVER_HOST", DEFAULT_HOST))
|
self._host: str = extra.get("host", os.getenv("API_SERVER_HOST", DEFAULT_HOST))
|
||||||
self._port: int = int(extra.get("port", os.getenv("API_SERVER_PORT", str(DEFAULT_PORT))))
|
self._port: int = int(extra.get("port", os.getenv("API_SERVER_PORT", str(DEFAULT_PORT))))
|
||||||
|
|
||||||
|
# SECURITY FIX (V-009): Fail-secure default for API key
|
||||||
|
# Previously: Empty API key allowed all requests (dangerous default)
|
||||||
|
# Now: Require explicit "allow_unauthenticated" setting to disable auth
|
||||||
self._api_key: str = extra.get("key", os.getenv("API_SERVER_KEY", ""))
|
self._api_key: str = extra.get("key", os.getenv("API_SERVER_KEY", ""))
|
||||||
|
self._allow_unauthenticated: bool = extra.get(
|
||||||
|
"allow_unauthenticated",
|
||||||
|
os.getenv("API_SERVER_ALLOW_UNAUTHENTICATED", "").lower() in ("true", "1", "yes")
|
||||||
|
)
|
||||||
|
|
||||||
|
# SECURITY: Log warning if no API key configured
|
||||||
|
if not self._api_key and not self._allow_unauthenticated:
|
||||||
|
logger.warning(
|
||||||
|
"API_SERVER_KEY not configured. All requests will be rejected. "
|
||||||
|
"Set API_SERVER_ALLOW_UNAUTHENTICATED=true for local-only use, "
|
||||||
|
"or configure API_SERVER_KEY for production."
|
||||||
|
)
|
||||||
|
elif not self._api_key and self._allow_unauthenticated:
|
||||||
|
logger.warning(
|
||||||
|
"API_SERVER running without authentication. "
|
||||||
|
"This is only safe for local-only deployments."
|
||||||
|
)
|
||||||
|
|
||||||
self._cors_origins: tuple[str, ...] = self._parse_cors_origins(
|
self._cors_origins: tuple[str, ...] = self._parse_cors_origins(
|
||||||
extra.get("cors_origins", os.getenv("API_SERVER_CORS_ORIGINS", "")),
|
extra.get("cors_origins", os.getenv("API_SERVER_CORS_ORIGINS", "")),
|
||||||
)
|
)
|
||||||
@@ -317,15 +339,22 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||||||
return tuple(str(item).strip() for item in items if str(item).strip())
|
return tuple(str(item).strip() for item in items if str(item).strip())
|
||||||
|
|
||||||
def _cors_headers_for_origin(self, origin: str) -> Optional[Dict[str, str]]:
|
def _cors_headers_for_origin(self, origin: str) -> Optional[Dict[str, str]]:
|
||||||
"""Return CORS headers for an allowed browser origin."""
|
"""Return CORS headers for an allowed browser origin.
|
||||||
|
|
||||||
|
SECURITY FIX (V-008): Never allow wildcard "*" with credentials.
|
||||||
|
If "*" is configured, we reject the request to prevent security issues.
|
||||||
|
"""
|
||||||
if not origin or not self._cors_origins:
|
if not origin or not self._cors_origins:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# SECURITY FIX (V-008): Reject wildcard CORS origins
|
||||||
|
# Wildcard with credentials is a security vulnerability
|
||||||
if "*" in self._cors_origins:
|
if "*" in self._cors_origins:
|
||||||
headers = dict(_CORS_HEADERS)
|
logger.warning(
|
||||||
headers["Access-Control-Allow-Origin"] = "*"
|
"CORS wildcard '*' is not allowed for security reasons. "
|
||||||
headers["Access-Control-Max-Age"] = "600"
|
"Please configure specific origins in API_SERVER_CORS_ORIGINS."
|
||||||
return headers
|
)
|
||||||
|
return None # Reject wildcard - too dangerous
|
||||||
|
|
||||||
if origin not in self._cors_origins:
|
if origin not in self._cors_origins:
|
||||||
return None
|
return None
|
||||||
@@ -355,10 +384,22 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||||||
Validate Bearer token from Authorization header.
|
Validate Bearer token from Authorization header.
|
||||||
|
|
||||||
Returns None if auth is OK, or a 401 web.Response on failure.
|
Returns None if auth is OK, or a 401 web.Response on failure.
|
||||||
If no API key is configured, all requests are allowed.
|
|
||||||
|
SECURITY FIX (V-009): Fail-secure default
|
||||||
|
- If no API key is configured AND allow_unauthenticated is not set,
|
||||||
|
all requests are rejected (secure by default)
|
||||||
|
- Only allow unauthenticated requests if explicitly configured
|
||||||
"""
|
"""
|
||||||
if not self._api_key:
|
# SECURITY: Fail-secure default - reject if no key and not explicitly allowed
|
||||||
return None # No key configured — allow all (local-only use)
|
if not self._api_key and not self._allow_unauthenticated:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": {"message": "Authentication required. Configure API_SERVER_KEY or set API_SERVER_ALLOW_UNAUTHENTICATED=true for local development.", "type": "authentication_error", "code": "auth_required"}},
|
||||||
|
status=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow unauthenticated requests only if explicitly configured
|
||||||
|
if not self._api_key and self._allow_unauthenticated:
|
||||||
|
return None # Explicitly allowed for local-only use
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.startswith("Bearer "):
|
||||||
|
|||||||
167
hermes_state_patch.py
Normal file
167
hermes_state_patch.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""SQLite State Store patch for cross-process locking.
|
||||||
|
|
||||||
|
Addresses Issue #52: SQLite global write lock causes contention.
|
||||||
|
|
||||||
|
The problem: Multiple hermes processes (gateway + CLI + worktree agents)
|
||||||
|
share one state.db, but each process has its own threading.Lock.
|
||||||
|
This patch adds file-based locking for cross-process coordination.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import fcntl
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class CrossProcessLock:
    """File-based lock for cross-process SQLite coordination.

    Uses fcntl.flock() on a dedicated lock file for atomic cross-process
    locking (Unix only -- the ``fcntl`` module is unavailable on Windows,
    despite what an earlier docstring claimed; there is no LockFile or
    threading.Lock fallback in this implementation).  A threading.Lock
    additionally serializes acquire()/release() calls made from multiple
    threads of the same process.

    NOTE(review): the lock is process-wide, not per-thread.  A second
    thread calling acquire() while the file lock is already held gets
    True immediately without waiting -- confirm callers treat one
    instance as a single process-level lock.
    """

    def __init__(self, lock_path: Path):
        self.lock_path = lock_path
        # Make sure the directory holding the lock file exists.
        self.lock_path.parent.mkdir(parents=True, exist_ok=True)
        self._fd = None  # Open file handle while the lock is held, else None.
        self._thread_lock = threading.Lock()  # Intra-process serialization.

    def acquire(self, blocking: bool = True, timeout: float = None) -> bool:
        """Acquire the cross-process lock.

        Args:
            blocking: If True, wait for the lock; if False, try once.
            timeout: Maximum time in seconds to wait (None = forever).
                Honored in blocking mode as well (see BUG FIX below).

        Returns:
            True if lock acquired, False on timeout or contention.
        """
        with self._thread_lock:
            if self._fd is not None:
                return True  # Already held by this process.

            start = time.time()
            # BUG FIX: the original issued a *blocking* flock() whenever
            # blocking=True, which ignored `timeout` entirely (flock has
            # no timeout parameter), and it tested `if timeout`, so
            # timeout=0 meant "wait forever".  We now poll with LOCK_NB
            # whenever a timeout is set, and compare against None.
            use_nonblocking = (not blocking) or (timeout is not None)
            while True:
                handle = None
                try:
                    handle = open(self.lock_path, "w")
                    if use_nonblocking:
                        fcntl.flock(handle.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
                    else:
                        fcntl.flock(handle.fileno(), fcntl.LOCK_EX)
                    self._fd = handle
                    return True
                except (IOError, OSError):
                    if handle is not None:
                        handle.close()
                    if not blocking:
                        return False
                    if timeout is not None and (time.time() - start) >= timeout:
                        return False
                    # Random backoff to avoid thundering-herd retries.
                    time.sleep(random.uniform(0.01, 0.05))

    def release(self):
        """Release the lock (no-op if it is not currently held)."""
        with self._thread_lock:
            if self._fd is None:
                return
            try:
                fcntl.flock(self._fd.fileno(), fcntl.LOCK_UN)
            except (IOError, OSError):
                pass
            finally:
                # BUG FIX: the original closed the handle inside the same
                # try block as the unlock, so a failing flock() skipped
                # close() and leaked the file descriptor.
                try:
                    self._fd.close()
                except (IOError, OSError):
                    pass
                self._fd = None

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
|
||||||
|
|
||||||
|
|
||||||
|
def patch_sessiondb_for_cross_process_locking(SessionDBClass):
    """Monkey-patch SessionDB to use cross-process locking.

    This should be called early in application initialization.

    Usage:
        from hermes_state import SessionDB
        from hermes_state_patch import patch_sessiondb_for_cross_process_locking
        patch_sessiondb_for_cross_process_locking(SessionDB)
    """
    wrapped_init = SessionDBClass.__init__

    def init_with_file_lock(self, db_path=None):
        # Run the stock initializer first ...
        wrapped_init(self, db_path)

        # ... then swap its threading.Lock for a file-based lock kept
        # next to the database file.
        self._lock = CrossProcessLock(Path(self.db_path).parent / ".state.lock")

        # Retry harder: cross-process contention lasts longer than
        # intra-process contention.
        self._WRITE_MAX_RETRIES = 30  # Up from 15
        self._WRITE_RETRY_MIN_S = 0.050  # Up from 20ms
        self._WRITE_RETRY_MAX_S = 0.300  # Up from 150ms

    SessionDBClass.__init__ = init_with_file_lock
|
||||||
|
|
||||||
|
|
||||||
|
# Alternative: Direct modification patch
def apply_sqlite_contention_fix():
    """Apply the SQLite contention fix directly to hermes_state module.

    Replaces ``hermes_state.SessionDB`` with a subclass whose lock is a
    CrossProcessLock instead of a plain threading.Lock, and which retries
    writes more aggressively.
    """
    import hermes_state

    original_SessionDB = hermes_state.SessionDB

    class PatchedSessionDB(original_SessionDB):
        """SessionDB with cross-process locking."""

        def __init__(self, db_path=None):
            # Import here to avoid circular imports
            from pathlib import Path
            from hermes_constants import get_hermes_home

            DEFAULT_DB_PATH = get_hermes_home() / "state.db"
            self.db_path = db_path or DEFAULT_DB_PATH

            # BUG FIX: build ONE lock object and reuse it.  The original
            # created two distinct CrossProcessLock instances (one before
            # and one after super().__init__) for the same lock file.
            lock_path = Path(self.db_path).parent / ".state.lock"
            file_lock = CrossProcessLock(lock_path)

            # Install before parent init so any locking the parent does
            # during initialization already goes through the file lock.
            self._lock = file_lock

            # BUG FIX: forward the *resolved* path.  The original passed
            # the raw argument (possibly None), letting the parent
            # re-derive its own default independently of self.db_path.
            super().__init__(self.db_path)

            # The parent init presumably reassigns self._lock to a plain
            # threading.Lock -- put the file lock back afterwards.
            self._lock = file_lock

            # More aggressive retry for cross-process contention
            self._WRITE_MAX_RETRIES = 30
            self._WRITE_RETRY_MIN_S = 0.050
            self._WRITE_RETRY_MAX_S = 0.300

    hermes_state.SessionDB = PatchedSessionDB
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: grab and drop the lock once.
    smoke_lock = CrossProcessLock(Path("/tmp/test_cross_process.lock"))
    print("Testing cross-process lock...")

    with smoke_lock:
        print("Lock acquired")
        time.sleep(0.1)

    print("Lock released")
    print("✅ Cross-process lock test passed")
|
||||||
143
tests/tools/test_command_injection.py
Normal file
143
tests/tools/test_command_injection.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""Tests for command injection protection (V-001).
|
||||||
|
|
||||||
|
Validates that subprocess calls use safe list-based execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import subprocess
|
||||||
|
import shlex
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubprocessSecurity:
    """Test subprocess security patterns."""

    def test_no_shell_true_in_tools(self):
        """Verify no tool uses shell=True with user input.

        This is a static analysis check - scan for dangerous patterns.

        NOTE(review): this test only *reports* violations; it never
        asserts, so it cannot fail.  Confirm whether that is intended.
        """
        import ast
        import os

        tools_dir = "tools"
        violations = []

        for root, dirs, files in os.walk(tools_dir):
            for name in files:
                if not name.endswith('.py'):
                    continue

                full_path = os.path.join(root, name)
                with open(full_path, 'r') as fh:
                    source = fh.read()

                # Cheap text scan first; AST confirms a real shell=True keyword.
                if 'shell=True' not in source:
                    continue
                try:
                    module = ast.parse(source)
                except SyntaxError:
                    continue
                for node in ast.walk(module):
                    if (isinstance(node, ast.keyword)
                            and node.arg == 'shell'
                            and isinstance(node.value, ast.Constant)
                            and node.value.value is True):
                        violations.append(f"{full_path}: shell=True found")

        # Document known-safe uses
        known_safe = [
            "cleanup operations with validated container IDs",
        ]

        if violations:
            print(f"Found {len(violations)} shell=True uses:")
            for v in violations:
                print(f" - {v}")

    def test_shlex_split_safety(self):
        """Test shlex.split handles various inputs safely."""
        cases = {
            "echo hello": ["echo", "hello"],
            "echo 'hello world'": ["echo", "hello world"],
            "echo \"test\"": ["echo", "test"],
        }
        for raw, tokens in cases.items():
            assert shlex.split(raw) == tokens
|
||||||
|
|
||||||
|
|
||||||
|
class TestDockerSecurity:
    """Test Docker environment security."""

    def test_container_id_validation(self):
        """Test container ID format validation."""
        import re

        # Container IDs are lowercase hex, 12-64 characters.
        container_id_re = re.compile(r'^[a-f0-9]{12,64}$')

        # Valid container IDs (hex, 12-64 chars)
        for good in ("abc123def456", "a" * 64, "1234567890ab"):
            assert container_id_re.match(good), f"Should be valid: {good}"

        # Invalid: wrong charset, wrong length, injection attempts.
        for bad in ("not-hex-chars",  # Contains hyphens and non-hex
                    "short",          # Too short
                    "a" * 65,         # Too long
                    "; rm -rf /",     # Command injection attempt
                    "$(whoami)"):     # Shell injection
            assert not container_id_re.match(bad), f"Should be invalid: {bad}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTranscriptionSecurity:
    """Test transcription tool command safety."""

    def test_command_template_formatting(self):
        """Test that command templates are formatted safely."""
        template = "whisper {input_path} --output_dir {output_dir}"

        # Normal inputs
        rendered = template.format(input_path="/path/to/audio.wav",
                                   output_dir="/tmp/output")
        assert "whisper /path/to/audio.wav" in rendered

        # Attempted injection in input path
        rendered = template.format(input_path="/path/to/file; rm -rf /",
                                   output_dir="/tmp/output")
        # Template formatting doesn't sanitize - that's why we use shlex.split
        assert "; rm -rf /" in rendered
|
||||||
|
|
||||||
|
|
||||||
|
class TestInputValidation:
    """Test input validation across tools."""

    @pytest.mark.parametrize("input_val,expected_safe", [
        ("/normal/path", True),
        ("normal_command", True),
        ("../../etc/passwd", False),
        ("; rm -rf /", False),
        ("$(whoami)", False),
        ("`cat /etc/passwd`", False),
    ])
    def test_dangerous_patterns(self, input_val, expected_safe):
        """Test detection of dangerous shell patterns."""
        # A value is "safe" only if it contains none of these tokens.
        blocked_tokens = ['..', ';', '&&', '||', '`', '$', '|']
        observed_safe = all(tok not in input_val for tok in blocked_tokens)
        assert observed_safe == expected_safe
|
||||||
161
tests/tools/test_path_traversal.py
Normal file
161
tests/tools/test_path_traversal.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""Comprehensive tests for path traversal protection (V-002).
|
||||||
|
|
||||||
|
Validates that file operations correctly block malicious paths.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from tools.file_operations import (
|
||||||
|
_contains_path_traversal,
|
||||||
|
_validate_safe_path,
|
||||||
|
ShellFileOperations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathTraversalDetection:
    """Test path traversal pattern detection."""

    @pytest.mark.parametrize("path,expected", [
        # Unix-style traversal
        ("../../../etc/passwd", True),
        ("../secret.txt", True),
        ("foo/../../bar", True),

        # Windows-style traversal
        ("..\\..\\windows\\system32", True),
        ("foo\\..\\bar", True),

        # URL-encoded
        ("%2e%2e%2fetc%2fpasswd", True),
        ("%2E%2E/%2Ftest", True),

        # Double slash
        ("..//..//etc/passwd", True),

        # Tilde escape
        ("~/../../../etc/shadow", True),

        # Null byte injection
        ("/etc/passwd\x00.txt", True),

        # Safe paths
        ("/home/user/file.txt", False),
        ("./relative/path", False),
        ("~/documents/file", False),
        ("normal_file_name", False),
    ])
    def test_contains_path_traversal(self, path, expected):
        """Each candidate must be classified exactly as expected."""
        verdict = _contains_path_traversal(path)
        assert verdict == expected, f"Path: {repr(path)}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestPathValidation:
    """Test comprehensive path validation."""

    def test_validate_safe_path_valid(self):
        """Well-formed paths must pass validation."""
        for candidate in ("/home/user/file.txt", "./relative/path",
                          "~/documents", "normal_file"):
            ok, reason = _validate_safe_path(candidate)
            assert ok is True, f"Path should be valid: {candidate} - {reason}"

    def test_validate_safe_path_traversal(self):
        """Traversal paths must be rejected with an explanatory error."""
        ok, reason = _validate_safe_path("../../../etc/passwd")
        assert ok is False
        assert "Path traversal" in reason

    def test_validate_safe_path_null_byte(self):
        """Embedded NUL bytes must be rejected."""
        ok, _ = _validate_safe_path("/etc/passwd\x00.txt")
        assert ok is False

    def test_validate_safe_path_empty(self):
        """The empty string must be rejected."""
        ok, reason = _validate_safe_path("")
        assert ok is False
        assert "empty" in reason.lower()

    def test_validate_safe_path_control_chars(self):
        """Control characters must be rejected."""
        ok, reason = _validate_safe_path("/path/with/\x01/control")
        assert ok is False
        assert "control" in reason.lower()

    def test_validate_safe_path_very_long(self):
        """Absurdly long paths must be rejected."""
        ok, _ = _validate_safe_path("a" * 5000)
        assert ok is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestShellFileOperationsSecurity:
    """Test security integration in ShellFileOperations."""

    @staticmethod
    def _make_ops():
        """Build a ShellFileOperations over a stubbed environment."""
        return ShellFileOperations(MagicMock())

    def test_read_file_blocks_traversal(self):
        """read_file must refuse traversal paths before touching the env."""
        outcome = self._make_ops().read_file("../../../etc/passwd")
        assert outcome.error is not None
        assert "Security violation" in outcome.error

    def test_write_file_blocks_traversal(self):
        """write_file must refuse traversal paths before touching the env."""
        outcome = self._make_ops().write_file("../../../etc/cron.d/backdoor",
                                              "malicious")
        assert outcome.error is not None
        assert "Security violation" in outcome.error
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Test edge cases and bypass attempts."""

    @pytest.mark.parametrize("path", [
        # Mixed case
        "..%2F..%2Fetc%2Fpasswd",
        "%2e.%2f",
        # Unicode normalization bypasses
        "\u2025\u2025/etc/passwd",  # Double dot characters
        "\u2024\u2024/etc/passwd",  # One dot characters
    ])
    def test_advanced_bypass_attempts(self, path):
        """The validator must survive exotic encodings without crashing.

        These inputs are expected to be caught by the length or control
        character checks; at minimum the call must return a boolean.
        """
        verdict, _ = _validate_safe_path(path)
        assert isinstance(verdict, bool)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformance:
    """Test validation performance with many paths."""

    def test_bulk_validation_performance(self):
        """1000 validations should finish comfortably under one second."""
        import time

        sample = ["/home/user/file" + str(i) + ".txt" for i in range(1000)]

        t0 = time.time()
        for candidate in sample:
            _validate_safe_path(candidate)
        elapsed = time.time() - t0

        assert elapsed < 1.0, f"Validation too slow: {elapsed}s"
|
||||||
@@ -253,6 +253,26 @@ class DockerEnvironment(BaseEnvironment):
|
|||||||
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
||||||
from tools.environments.base import get_sandbox_dir
|
from tools.environments.base import get_sandbox_dir
|
||||||
|
|
||||||
|
# SECURITY FIX (V-012): Block dangerous volume mounts
|
||||||
|
# Prevent privilege escalation via Docker socket or sensitive paths
|
||||||
|
_BLOCKED_VOLUME_PATTERNS = [
|
||||||
|
"/var/run/docker.sock",
|
||||||
|
"/run/docker.sock",
|
||||||
|
"/var/run/docker.pid",
|
||||||
|
"/proc", "/sys", "/dev",
|
||||||
|
":/", # Root filesystem mount
|
||||||
|
]
|
||||||
|
|
||||||
|
def _is_dangerous_volume(vol_spec: str) -> bool:
|
||||||
|
"""Check if volume spec is dangerous (docker socket, root fs, etc)."""
|
||||||
|
for pattern in _BLOCKED_VOLUME_PATTERNS:
|
||||||
|
if pattern in vol_spec:
|
||||||
|
return True
|
||||||
|
# Check for docker socket variations
|
||||||
|
if "docker.sock" in vol_spec.lower():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
# User-configured volume mounts (from config.yaml docker_volumes)
|
# User-configured volume mounts (from config.yaml docker_volumes)
|
||||||
volume_args = []
|
volume_args = []
|
||||||
workspace_explicitly_mounted = False
|
workspace_explicitly_mounted = False
|
||||||
@@ -263,6 +283,15 @@ class DockerEnvironment(BaseEnvironment):
|
|||||||
vol = vol.strip()
|
vol = vol.strip()
|
||||||
if not vol:
|
if not vol:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# SECURITY FIX (V-012): Block dangerous volumes
|
||||||
|
if _is_dangerous_volume(vol):
|
||||||
|
logger.error(
|
||||||
|
f"SECURITY: Refusing to mount dangerous volume '{vol}'. "
|
||||||
|
f"Docker socket and system paths are blocked to prevent container escape."
|
||||||
|
)
|
||||||
|
continue # Skip this dangerous volume
|
||||||
|
|
||||||
if ":" in vol:
|
if ":" in vol:
|
||||||
volume_args.extend(["-v", vol])
|
volume_args.extend(["-v", vol])
|
||||||
if ":/workspace" in vol:
|
if ":/workspace" in vol:
|
||||||
|
|||||||
199
validate_security.py
Normal file
199
validate_security.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Comprehensive security validation script.
|
||||||
|
|
||||||
|
Runs all security checks and reports status.
|
||||||
|
Usage: python validate_security.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import ast
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
class SecurityValidator:
    """Run comprehensive security validations."""

    def __init__(self):
        self.issues = []        # Hard failures (fail the run)
        self.warnings = []      # Soft findings (reported only)
        self.checks_passed = 0
        self.checks_failed = 0

    # -- bookkeeping helpers ------------------------------------------------

    def _record_pass(self, line):
        """Print a passing check and bump the pass counter."""
        print(line)
        self.checks_passed += 1

    def _record_fail(self, line, issue):
        """Print a failing check, record the issue, bump the fail counter."""
        print(line)
        self.issues.append(issue)
        self.checks_failed += 1

    def _record_warn(self, line, warning):
        """Print a soft finding and record the warning."""
        print(line)
        self.warnings.append(warning)

    # -- checks -------------------------------------------------------------

    def run_all(self):
        """Run all security checks; return True when no issues were found."""
        print("=" * 80)
        print("🔒 SECURITY VALIDATION SUITE")
        print("=" * 80)

        self.check_command_injection()
        self.check_path_traversal()
        self.check_ssrf_protection()
        self.check_secret_leakage()
        self.check_interrupt_race_conditions()
        self.check_test_coverage()

        self.print_summary()
        return not self.issues

    def check_command_injection(self):
        """Check for command injection vulnerabilities."""
        print("\n[1/6] Checking command injection protections...")

        # transcription_tools.py must use shlex.split with shell=False.
        text = Path("tools/transcription_tools.py").read_text()
        if "shlex.split" in text and "shell=False" in text:
            self._record_pass(" ✅ transcription_tools.py: Uses safe list-based execution")
        else:
            self._record_fail(" ❌ transcription_tools.py: May use unsafe shell execution",
                              "Command injection in transcription_tools")

        # docker.py must validate container ID format.
        text = Path("tools/environments/docker.py").read_text()
        if "re.match" in text and "container" in text:
            self._record_pass(" ✅ docker.py: Validates container ID format")
        else:
            self._record_warn(" ⚠️ docker.py: Container ID validation not confirmed",
                              "Docker container ID validation")

    def check_path_traversal(self):
        """Check for path traversal protections."""
        print("\n[2/6] Checking path traversal protections...")

        text = Path("tools/file_operations.py").read_text()

        expected_markers = (
            ("_validate_safe_path", "Path validation function"),
            ("_contains_path_traversal", "Traversal detection function"),
            ("../", "Unix traversal pattern"),
            ("..\\\\", "Windows traversal pattern"),
            ("\\\\x00", "Null byte detection"),
        )

        for marker, description in expected_markers:
            if marker in text:
                self._record_pass(f" ✅ {description}")
            else:
                self._record_fail(f" ❌ Missing: {description}",
                                  f"Path traversal: {description}")

    def check_ssrf_protection(self):
        """Check for SSRF protections."""
        print("\n[3/6] Checking SSRF protections...")

        text = Path("tools/url_safety.py").read_text()

        expected_markers = (
            ("_is_blocked_ip", "IP blocking function"),
            ("create_safe_socket", "Connection-level validation"),
            ("169.254", "Metadata service block"),
            ("is_private", "Private IP detection"),
        )

        for marker, description in expected_markers:
            if marker in text:
                self._record_pass(f" ✅ {description}")
            else:
                self._record_warn(f" ⚠️ {description} not found",
                                  f"SSRF: {description}")

    def check_secret_leakage(self):
        """Check for secret leakage protections."""
        print("\n[4/6] Checking secret leakage protections...")

        text = Path("tools/code_execution_tool.py").read_text()

        if "_ALLOWED_ENV_VARS" in text:
            self._record_pass(" ✅ Uses whitelist for environment variables")
        elif "_SECRET_SUBSTRINGS" in text:
            self._record_warn(" ⚠️ Uses blacklist (may be outdated version)",
                              "Blacklist instead of whitelist for secrets")
        else:
            self._record_fail(" ❌ No secret filtering found",
                              "Secret leakage protection")

        # Flag any secret-looking variable names present in the source.
        dangerous_vars = ["API_KEY", "SECRET", "PASSWORD", "TOKEN"]
        found_dangerous = [v for v in dangerous_vars if v in text]

        if found_dangerous:
            print(f" ⚠️ Found potential secret vars in code: {found_dangerous}")

    def check_interrupt_race_conditions(self):
        """Check for interrupt race condition fixes."""
        print("\n[5/6] Checking interrupt race condition protections...")

        text = Path("tools/interrupt.py").read_text()

        expected_markers = (
            ("RLock", "Reentrant lock for thread safety"),
            ("_interrupt_lock", "Lock variable"),
            ("_interrupt_count", "Nesting count tracking"),
        )

        for marker, description in expected_markers:
            if marker in text:
                self._record_pass(f" ✅ {description}")
            else:
                self._record_fail(f" ❌ Missing: {description}",
                                  f"Interrupt: {description}")

    def check_test_coverage(self):
        """Check security test coverage."""
        print("\n[6/6] Checking security test coverage...")

        for test_file in ("tests/tools/test_interrupt.py",
                          "tests/tools/test_path_traversal.py",
                          "tests/tools/test_command_injection.py"):
            if Path(test_file).exists():
                self._record_pass(f" ✅ {test_file}")
            else:
                self._record_fail(f" ❌ Missing: {test_file}",
                                  f"Missing test: {test_file}")

    def print_summary(self):
        """Print validation summary."""
        banner = "=" * 80
        print("\n" + banner)
        print("VALIDATION SUMMARY")
        print(banner)
        print(f"Checks Passed: {self.checks_passed}")
        print(f"Checks Failed: {self.checks_failed}")
        print(f"Warnings: {len(self.warnings)}")

        if self.issues:
            print("\n❌ CRITICAL ISSUES:")
            for issue in self.issues:
                print(f" - {issue}")

        if self.warnings:
            print("\n⚠️ WARNINGS:")
            for warning in self.warnings:
                print(f" - {warning}")

        if not self.issues:
            print("\n✅ ALL SECURITY CHECKS PASSED")

        print(banner)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the full suite; exit nonzero when any critical issue is found.
    run_was_clean = SecurityValidator().run_all()
    sys.exit(0 if run_was_clean else 1)
|
||||||
Reference in New Issue
Block a user