Compare commits
10 Commits
tests/secu
...
security/a
| Author | SHA1 | Date | |
|---|---|---|---|
| 4e3f5072f6 | |||
| 5936745636 | |||
| cfaf6c827e | |||
| cf1afb07f2 | |||
| ed32487cbe | |||
| 37c5e672b5 | |||
| cfcffd38ab | |||
| 0b49540db3 | |||
| ffa8405cfb | |||
| cc1b9e8054 |
@@ -241,6 +241,43 @@ else:
|
|||||||
security_headers_middleware = None # type: ignore[assignment]
|
security_headers_middleware = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
# SECURITY FIX (V-016): Rate limiting middleware
|
||||||
|
if AIOHTTP_AVAILABLE:
|
||||||
|
@web.middleware
|
||||||
|
async def rate_limit_middleware(request, handler):
|
||||||
|
"""Apply rate limiting per client IP.
|
||||||
|
|
||||||
|
Returns 429 Too Many Requests if rate limit exceeded.
|
||||||
|
Configurable via API_SERVER_RATE_LIMIT env var (requests per minute).
|
||||||
|
"""
|
||||||
|
# Skip rate limiting for health checks
|
||||||
|
if request.path == "/health":
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
|
# Get client IP (respecting X-Forwarded-For if behind proxy)
|
||||||
|
client_ip = request.headers.get("X-Forwarded-For", request.remote)
|
||||||
|
if client_ip and "," in client_ip:
|
||||||
|
client_ip = client_ip.split(",")[0].strip()
|
||||||
|
|
||||||
|
limiter = _get_rate_limiter()
|
||||||
|
if not limiter.acquire(client_ip):
|
||||||
|
retry_after = limiter.get_retry_after(client_ip)
|
||||||
|
logger.warning(f"Rate limit exceeded for {client_ip}")
|
||||||
|
return web.json_response(
|
||||||
|
_openai_error(
|
||||||
|
f"Rate limit exceeded. Try again in {retry_after} seconds.",
|
||||||
|
err_type="rate_limit_error",
|
||||||
|
code="rate_limit_exceeded"
|
||||||
|
),
|
||||||
|
status=429,
|
||||||
|
headers={"Retry-After": str(retry_after)}
|
||||||
|
)
|
||||||
|
|
||||||
|
return await handler(request)
|
||||||
|
else:
|
||||||
|
rate_limit_middleware = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
class _IdempotencyCache:
|
class _IdempotencyCache:
|
||||||
"""In-memory idempotency cache with TTL and basic LRU semantics."""
|
"""In-memory idempotency cache with TTL and basic LRU semantics."""
|
||||||
def __init__(self, max_items: int = 1000, ttl_seconds: int = 300):
|
def __init__(self, max_items: int = 1000, ttl_seconds: int = 300):
|
||||||
@@ -273,6 +310,59 @@ class _IdempotencyCache:
|
|||||||
_idem_cache = _IdempotencyCache()
|
_idem_cache = _IdempotencyCache()
|
||||||
|
|
||||||
|
|
||||||
|
# SECURITY FIX (V-016): Rate limiting
|
||||||
|
class _RateLimiter:
|
||||||
|
"""Token bucket rate limiter per client IP.
|
||||||
|
|
||||||
|
Default: 100 requests per minute per IP.
|
||||||
|
Configurable via API_SERVER_RATE_LIMIT env var (requests per minute).
|
||||||
|
"""
|
||||||
|
def __init__(self, requests_per_minute: int = 100):
|
||||||
|
from collections import defaultdict
|
||||||
|
self._buckets = defaultdict(lambda: {"tokens": requests_per_minute, "last": 0})
|
||||||
|
self._rate = requests_per_minute / 60.0 # tokens per second
|
||||||
|
self._max_tokens = requests_per_minute
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def _get_bucket(self, key: str) -> dict:
|
||||||
|
import time
|
||||||
|
with self._lock:
|
||||||
|
bucket = self._buckets[key]
|
||||||
|
now = time.time()
|
||||||
|
elapsed = now - bucket["last"]
|
||||||
|
bucket["last"] = now
|
||||||
|
# Add tokens based on elapsed time
|
||||||
|
bucket["tokens"] = min(
|
||||||
|
self._max_tokens,
|
||||||
|
bucket["tokens"] + elapsed * self._rate
|
||||||
|
)
|
||||||
|
return bucket
|
||||||
|
|
||||||
|
def acquire(self, key: str) -> bool:
|
||||||
|
"""Try to acquire a token. Returns True if allowed, False if rate limited."""
|
||||||
|
bucket = self._get_bucket(key)
|
||||||
|
with self._lock:
|
||||||
|
if bucket["tokens"] >= 1:
|
||||||
|
bucket["tokens"] -= 1
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_retry_after(self, key: str) -> int:
|
||||||
|
"""Get seconds until next token is available."""
|
||||||
|
return 1 # Simplified - return 1 second
|
||||||
|
|
||||||
|
|
||||||
|
_rate_limiter = None
|
||||||
|
|
||||||
|
def _get_rate_limiter() -> _RateLimiter:
|
||||||
|
global _rate_limiter
|
||||||
|
if _rate_limiter is None:
|
||||||
|
# Parse rate limit from env (default 100 req/min)
|
||||||
|
rate_limit = int(os.getenv("API_SERVER_RATE_LIMIT", "100"))
|
||||||
|
_rate_limiter = _RateLimiter(rate_limit)
|
||||||
|
return _rate_limiter
|
||||||
|
|
||||||
|
|
||||||
def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str:
|
def _make_request_fingerprint(body: Dict[str, Any], keys: List[str]) -> str:
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
subset = {k: body.get(k) for k in keys}
|
subset = {k: body.get(k) for k in keys}
|
||||||
@@ -292,7 +382,29 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||||||
extra = config.extra or {}
|
extra = config.extra or {}
|
||||||
self._host: str = extra.get("host", os.getenv("API_SERVER_HOST", DEFAULT_HOST))
|
self._host: str = extra.get("host", os.getenv("API_SERVER_HOST", DEFAULT_HOST))
|
||||||
self._port: int = int(extra.get("port", os.getenv("API_SERVER_PORT", str(DEFAULT_PORT))))
|
self._port: int = int(extra.get("port", os.getenv("API_SERVER_PORT", str(DEFAULT_PORT))))
|
||||||
|
|
||||||
|
# SECURITY FIX (V-009): Fail-secure default for API key
|
||||||
|
# Previously: Empty API key allowed all requests (dangerous default)
|
||||||
|
# Now: Require explicit "allow_unauthenticated" setting to disable auth
|
||||||
self._api_key: str = extra.get("key", os.getenv("API_SERVER_KEY", ""))
|
self._api_key: str = extra.get("key", os.getenv("API_SERVER_KEY", ""))
|
||||||
|
self._allow_unauthenticated: bool = extra.get(
|
||||||
|
"allow_unauthenticated",
|
||||||
|
os.getenv("API_SERVER_ALLOW_UNAUTHENTICATED", "").lower() in ("true", "1", "yes")
|
||||||
|
)
|
||||||
|
|
||||||
|
# SECURITY: Log warning if no API key configured
|
||||||
|
if not self._api_key and not self._allow_unauthenticated:
|
||||||
|
logger.warning(
|
||||||
|
"API_SERVER_KEY not configured. All requests will be rejected. "
|
||||||
|
"Set API_SERVER_ALLOW_UNAUTHENTICATED=true for local-only use, "
|
||||||
|
"or configure API_SERVER_KEY for production."
|
||||||
|
)
|
||||||
|
elif not self._api_key and self._allow_unauthenticated:
|
||||||
|
logger.warning(
|
||||||
|
"API_SERVER running without authentication. "
|
||||||
|
"This is only safe for local-only deployments."
|
||||||
|
)
|
||||||
|
|
||||||
self._cors_origins: tuple[str, ...] = self._parse_cors_origins(
|
self._cors_origins: tuple[str, ...] = self._parse_cors_origins(
|
||||||
extra.get("cors_origins", os.getenv("API_SERVER_CORS_ORIGINS", "")),
|
extra.get("cors_origins", os.getenv("API_SERVER_CORS_ORIGINS", "")),
|
||||||
)
|
)
|
||||||
@@ -317,15 +429,22 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||||||
return tuple(str(item).strip() for item in items if str(item).strip())
|
return tuple(str(item).strip() for item in items if str(item).strip())
|
||||||
|
|
||||||
def _cors_headers_for_origin(self, origin: str) -> Optional[Dict[str, str]]:
|
def _cors_headers_for_origin(self, origin: str) -> Optional[Dict[str, str]]:
|
||||||
"""Return CORS headers for an allowed browser origin."""
|
"""Return CORS headers for an allowed browser origin.
|
||||||
|
|
||||||
|
SECURITY FIX (V-008): Never allow wildcard "*" with credentials.
|
||||||
|
If "*" is configured, we reject the request to prevent security issues.
|
||||||
|
"""
|
||||||
if not origin or not self._cors_origins:
|
if not origin or not self._cors_origins:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# SECURITY FIX (V-008): Reject wildcard CORS origins
|
||||||
|
# Wildcard with credentials is a security vulnerability
|
||||||
if "*" in self._cors_origins:
|
if "*" in self._cors_origins:
|
||||||
headers = dict(_CORS_HEADERS)
|
logger.warning(
|
||||||
headers["Access-Control-Allow-Origin"] = "*"
|
"CORS wildcard '*' is not allowed for security reasons. "
|
||||||
headers["Access-Control-Max-Age"] = "600"
|
"Please configure specific origins in API_SERVER_CORS_ORIGINS."
|
||||||
return headers
|
)
|
||||||
|
return None # Reject wildcard - too dangerous
|
||||||
|
|
||||||
if origin not in self._cors_origins:
|
if origin not in self._cors_origins:
|
||||||
return None
|
return None
|
||||||
@@ -355,10 +474,22 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||||||
Validate Bearer token from Authorization header.
|
Validate Bearer token from Authorization header.
|
||||||
|
|
||||||
Returns None if auth is OK, or a 401 web.Response on failure.
|
Returns None if auth is OK, or a 401 web.Response on failure.
|
||||||
If no API key is configured, all requests are allowed.
|
|
||||||
|
SECURITY FIX (V-009): Fail-secure default
|
||||||
|
- If no API key is configured AND allow_unauthenticated is not set,
|
||||||
|
all requests are rejected (secure by default)
|
||||||
|
- Only allow unauthenticated requests if explicitly configured
|
||||||
"""
|
"""
|
||||||
if not self._api_key:
|
# SECURITY: Fail-secure default - reject if no key and not explicitly allowed
|
||||||
return None # No key configured — allow all (local-only use)
|
if not self._api_key and not self._allow_unauthenticated:
|
||||||
|
return web.json_response(
|
||||||
|
{"error": {"message": "Authentication required. Configure API_SERVER_KEY or set API_SERVER_ALLOW_UNAUTHENTICATED=true for local development.", "type": "authentication_error", "code": "auth_required"}},
|
||||||
|
status=401,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow unauthenticated requests only if explicitly configured
|
||||||
|
if not self._api_key and self._allow_unauthenticated:
|
||||||
|
return None # Explicitly allowed for local-only use
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
auth_header = request.headers.get("Authorization", "")
|
||||||
if auth_header.startswith("Bearer "):
|
if auth_header.startswith("Bearer "):
|
||||||
@@ -1241,7 +1372,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware) if mw is not None]
|
# SECURITY FIX (V-016): Add rate limiting middleware
|
||||||
|
mws = [mw for mw in (cors_middleware, body_limit_middleware, security_headers_middleware, rate_limit_middleware) if mw is not None]
|
||||||
self._app = web.Application(middlewares=mws)
|
self._app = web.Application(middlewares=mws)
|
||||||
self._app["api_server_adapter"] = self
|
self._app["api_server_adapter"] = self
|
||||||
self._app.router.add_get("/health", self._handle_health)
|
self._app.router.add_get("/health", self._handle_health)
|
||||||
|
|||||||
167
hermes_state_patch.py
Normal file
167
hermes_state_patch.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
"""SQLite State Store patch for cross-process locking.
|
||||||
|
|
||||||
|
Addresses Issue #52: SQLite global write lock causes contention.
|
||||||
|
|
||||||
|
The problem: Multiple hermes processes (gateway + CLI + worktree agents)
|
||||||
|
share one state.db, but each process has its own threading.Lock.
|
||||||
|
This patch adds file-based locking for cross-process coordination.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import fcntl
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import random
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class CrossProcessLock:
|
||||||
|
"""File-based lock for cross-process SQLite coordination.
|
||||||
|
|
||||||
|
Uses flock() on Unix and LockFile on Windows for atomic
|
||||||
|
cross-process locking. Falls back to threading.Lock if
|
||||||
|
file locking fails.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, lock_path: Path):
|
||||||
|
self.lock_path = lock_path
|
||||||
|
self.lock_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._fd = None
|
||||||
|
self._thread_lock = threading.Lock()
|
||||||
|
|
||||||
|
def acquire(self, blocking: bool = True, timeout: float = None) -> bool:
|
||||||
|
"""Acquire the cross-process lock.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
blocking: If True, block until lock is acquired
|
||||||
|
timeout: Maximum time to wait (None = forever)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if lock acquired, False if timeout
|
||||||
|
"""
|
||||||
|
with self._thread_lock:
|
||||||
|
if self._fd is not None:
|
||||||
|
return True # Already held
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
self._fd = open(self.lock_path, "w")
|
||||||
|
if blocking:
|
||||||
|
fcntl.flock(self._fd.fileno(), fcntl.LOCK_EX)
|
||||||
|
else:
|
||||||
|
fcntl.flock(self._fd.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
|
||||||
|
return True
|
||||||
|
except (IOError, OSError) as e:
|
||||||
|
if self._fd:
|
||||||
|
self._fd.close()
|
||||||
|
self._fd = None
|
||||||
|
|
||||||
|
if not blocking:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if timeout and (time.time() - start) >= timeout:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Random backoff
|
||||||
|
time.sleep(random.uniform(0.01, 0.05))
|
||||||
|
|
||||||
|
def release(self):
|
||||||
|
"""Release the lock."""
|
||||||
|
with self._thread_lock:
|
||||||
|
if self._fd is not None:
|
||||||
|
try:
|
||||||
|
fcntl.flock(self._fd.fileno(), fcntl.LOCK_UN)
|
||||||
|
self._fd.close()
|
||||||
|
except (IOError, OSError):
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._fd = None
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.acquire()
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
self.release()
|
||||||
|
|
||||||
|
|
||||||
|
def patch_sessiondb_for_cross_process_locking(SessionDBClass):
|
||||||
|
"""Monkey-patch SessionDB to use cross-process locking.
|
||||||
|
|
||||||
|
This should be called early in application initialization.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from hermes_state import SessionDB
|
||||||
|
from hermes_state_patch import patch_sessiondb_for_cross_process_locking
|
||||||
|
patch_sessiondb_for_cross_process_locking(SessionDB)
|
||||||
|
"""
|
||||||
|
original_init = SessionDBClass.__init__
|
||||||
|
|
||||||
|
def patched_init(self, db_path=None):
|
||||||
|
# Call original init but replace the lock
|
||||||
|
original_init(self, db_path)
|
||||||
|
|
||||||
|
# Replace threading.Lock with cross-process lock
|
||||||
|
lock_path = Path(self.db_path).parent / ".state.lock"
|
||||||
|
self._lock = CrossProcessLock(lock_path)
|
||||||
|
|
||||||
|
# Increase retries for cross-process contention
|
||||||
|
self._WRITE_MAX_RETRIES = 30 # Up from 15
|
||||||
|
self._WRITE_RETRY_MIN_S = 0.050 # Up from 20ms
|
||||||
|
self._WRITE_RETRY_MAX_S = 0.300 # Up from 150ms
|
||||||
|
|
||||||
|
SessionDBClass.__init__ = patched_init
|
||||||
|
|
||||||
|
|
||||||
|
# Alternative: Direct modification patch
|
||||||
|
def apply_sqlite_contention_fix():
|
||||||
|
"""Apply the SQLite contention fix directly to hermes_state module."""
|
||||||
|
import hermes_state
|
||||||
|
|
||||||
|
original_SessionDB = hermes_state.SessionDB
|
||||||
|
|
||||||
|
class PatchedSessionDB(original_SessionDB):
|
||||||
|
"""SessionDB with cross-process locking."""
|
||||||
|
|
||||||
|
def __init__(self, db_path=None):
|
||||||
|
# Import here to avoid circular imports
|
||||||
|
from pathlib import Path
|
||||||
|
from hermes_constants import get_hermes_home
|
||||||
|
|
||||||
|
DEFAULT_DB_PATH = get_hermes_home() / "state.db"
|
||||||
|
self.db_path = db_path or DEFAULT_DB_PATH
|
||||||
|
|
||||||
|
# Setup cross-process lock before parent init
|
||||||
|
lock_path = Path(self.db_path).parent / ".state.lock"
|
||||||
|
self._lock = CrossProcessLock(lock_path)
|
||||||
|
|
||||||
|
# Call parent init but skip lock creation
|
||||||
|
super().__init__(db_path)
|
||||||
|
|
||||||
|
# Override the lock parent created
|
||||||
|
self._lock = CrossProcessLock(lock_path)
|
||||||
|
|
||||||
|
# More aggressive retry for cross-process
|
||||||
|
self._WRITE_MAX_RETRIES = 30
|
||||||
|
self._WRITE_RETRY_MIN_S = 0.050
|
||||||
|
self._WRITE_RETRY_MAX_S = 0.300
|
||||||
|
|
||||||
|
hermes_state.SessionDB = PatchedSessionDB
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Test the lock
|
||||||
|
lock = CrossProcessLock(Path("/tmp/test_cross_process.lock"))
|
||||||
|
print("Testing cross-process lock...")
|
||||||
|
|
||||||
|
with lock:
|
||||||
|
print("Lock acquired")
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
print("Lock released")
|
||||||
|
print("✅ Cross-process lock test passed")
|
||||||
@@ -170,6 +170,9 @@ def _resolve_cdp_override(cdp_url: str) -> str:
|
|||||||
For discovery-style endpoints we fetch /json/version and return the
|
For discovery-style endpoints we fetch /json/version and return the
|
||||||
webSocketDebuggerUrl so downstream tools always receive a concrete browser
|
webSocketDebuggerUrl so downstream tools always receive a concrete browser
|
||||||
websocket instead of an ambiguous host:port URL.
|
websocket instead of an ambiguous host:port URL.
|
||||||
|
|
||||||
|
SECURITY FIX (V-010): Validates URLs before fetching to prevent SSRF.
|
||||||
|
Only allows localhost/private network addresses for CDP connections.
|
||||||
"""
|
"""
|
||||||
raw = (cdp_url or "").strip()
|
raw = (cdp_url or "").strip()
|
||||||
if not raw:
|
if not raw:
|
||||||
@@ -191,6 +194,35 @@ def _resolve_cdp_override(cdp_url: str) -> str:
|
|||||||
else:
|
else:
|
||||||
version_url = discovery_url.rstrip("/") + "/json/version"
|
version_url = discovery_url.rstrip("/") + "/json/version"
|
||||||
|
|
||||||
|
# SECURITY FIX (V-010): Validate URL before fetching
|
||||||
|
# Only allow localhost and private networks for CDP
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
parsed = urlparse(version_url)
|
||||||
|
hostname = parsed.hostname or ""
|
||||||
|
|
||||||
|
# Allow only safe hostnames for CDP
|
||||||
|
allowed_hostnames = ["localhost", "127.0.0.1", "0.0.0.0", "::1"]
|
||||||
|
if hostname not in allowed_hostnames:
|
||||||
|
# Check if it's a private IP
|
||||||
|
try:
|
||||||
|
import ipaddress
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
if not (ip.is_private or ip.is_loopback):
|
||||||
|
logger.error(
|
||||||
|
"SECURITY: Rejecting CDP URL '%s' - only localhost and private "
|
||||||
|
"networks are allowed to prevent SSRF attacks.",
|
||||||
|
raw
|
||||||
|
)
|
||||||
|
return raw # Return original without fetching
|
||||||
|
except ValueError:
|
||||||
|
# Not an IP - reject unknown hostnames
|
||||||
|
logger.error(
|
||||||
|
"SECURITY: Rejecting CDP URL '%s' - unknown hostname '%s'. "
|
||||||
|
"Only localhost and private IPs are allowed.",
|
||||||
|
raw, hostname
|
||||||
|
)
|
||||||
|
return raw
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(version_url, timeout=10)
|
response = requests.get(version_url, timeout=10)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|||||||
@@ -253,6 +253,26 @@ class DockerEnvironment(BaseEnvironment):
|
|||||||
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
# mode uses tmpfs (ephemeral, fast, gone on cleanup).
|
||||||
from tools.environments.base import get_sandbox_dir
|
from tools.environments.base import get_sandbox_dir
|
||||||
|
|
||||||
|
# SECURITY FIX (V-012): Block dangerous volume mounts
|
||||||
|
# Prevent privilege escalation via Docker socket or sensitive paths
|
||||||
|
_BLOCKED_VOLUME_PATTERNS = [
|
||||||
|
"/var/run/docker.sock",
|
||||||
|
"/run/docker.sock",
|
||||||
|
"/var/run/docker.pid",
|
||||||
|
"/proc", "/sys", "/dev",
|
||||||
|
":/", # Root filesystem mount
|
||||||
|
]
|
||||||
|
|
||||||
|
def _is_dangerous_volume(vol_spec: str) -> bool:
|
||||||
|
"""Check if volume spec is dangerous (docker socket, root fs, etc)."""
|
||||||
|
for pattern in _BLOCKED_VOLUME_PATTERNS:
|
||||||
|
if pattern in vol_spec:
|
||||||
|
return True
|
||||||
|
# Check for docker socket variations
|
||||||
|
if "docker.sock" in vol_spec.lower():
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
# User-configured volume mounts (from config.yaml docker_volumes)
|
# User-configured volume mounts (from config.yaml docker_volumes)
|
||||||
volume_args = []
|
volume_args = []
|
||||||
workspace_explicitly_mounted = False
|
workspace_explicitly_mounted = False
|
||||||
@@ -263,6 +283,15 @@ class DockerEnvironment(BaseEnvironment):
|
|||||||
vol = vol.strip()
|
vol = vol.strip()
|
||||||
if not vol:
|
if not vol:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# SECURITY FIX (V-012): Block dangerous volumes
|
||||||
|
if _is_dangerous_volume(vol):
|
||||||
|
logger.error(
|
||||||
|
f"SECURITY: Refusing to mount dangerous volume '{vol}'. "
|
||||||
|
f"Docker socket and system paths are blocked to prevent container escape."
|
||||||
|
)
|
||||||
|
continue # Skip this dangerous volume
|
||||||
|
|
||||||
if ":" in vol:
|
if ":" in vol:
|
||||||
volume_args.extend(["-v", vol])
|
volume_args.extend(["-v", vol])
|
||||||
if ":/workspace" in vol:
|
if ":/workspace" in vol:
|
||||||
|
|||||||
Reference in New Issue
Block a user