security(gateway): isolate env/credential registries using ContextVars

This commit is contained in:
Dusk1e
2026-04-06 16:05:15 +03:00
committed by Teknium
parent da02a4e283
commit 7d0953d6ff
3 changed files with 47 additions and 11 deletions

View File

@@ -22,14 +22,26 @@ from __future__ import annotations
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Dict, List
logger = logging.getLogger(__name__)
# Session-scoped list of credential files to mount.
# Key: container_path (deduplicated), Value: host_path
_registered_files: Dict[str, str] = {}
# Backed by ContextVar to prevent cross-session data bleed in the gateway pipeline.
_registered_files_var: ContextVar[Dict[str, str]] = ContextVar("_registered_files")
def _get_registered() -> Dict[str, str]:
"""Get or create the registered credential files dict for the current context/session."""
try:
return _registered_files_var.get()
except LookupError:
val: Dict[str, str] = {}
_registered_files_var.set(val)
return val
# Cache for config-based file list (loaded once per process).
_config_files: List[Dict[str, str]] | None = None
@@ -86,7 +98,7 @@ def register_credential_file(
return False
container_path = f"{container_base.rstrip('/')}/{relative_path}"
_registered_files[container_path] = str(resolved)
_get_registered()[container_path] = str(resolved)
logger.debug("credential_files: registered %s -> %s", resolved, container_path)
return True
@@ -174,7 +186,7 @@ def get_credential_file_mounts() -> List[Dict[str, str]]:
mounts: Dict[str, str] = {}
# Skill-registered files
for container_path, host_path in _registered_files.items():
for container_path, host_path in _get_registered().items():
# Re-check existence (file may have been deleted since registration)
if Path(host_path).is_file():
mounts[container_path] = host_path
@@ -395,7 +407,7 @@ def iter_cache_files(
def clear_credential_files() -> None:
"""Reset the skill-scoped registry (e.g. on session reset)."""
_registered_files.clear()
_get_registered().clear()
def reset_config_cache() -> None:

View File

@@ -21,13 +21,25 @@ from __future__ import annotations
import logging
import os
from pathlib import Path
from contextvars import ContextVar
from typing import Iterable
logger = logging.getLogger(__name__)
# Session-scoped set of env var names that should pass through to sandboxes.
_allowed_env_vars: set[str] = set()
# Backed by ContextVar to prevent cross-session data bleed in the gateway pipeline.
_allowed_env_vars_var: ContextVar[set[str]] = ContextVar("_allowed_env_vars")
def _get_allowed() -> set[str]:
"""Get or create the allowed env vars set for the current context/session."""
try:
return _allowed_env_vars_var.get()
except LookupError:
val: set[str] = set()
_allowed_env_vars_var.set(val)
return val
# Cache for the config-based allowlist (loaded once per process).
_config_passthrough: frozenset[str] | None = None
@@ -41,7 +53,7 @@ def register_env_passthrough(var_names: Iterable[str]) -> None:
for name in var_names:
name = name.strip()
if name:
_allowed_env_vars.add(name)
_get_allowed().add(name)
logger.debug("env passthrough: registered %s", name)
@@ -78,19 +90,19 @@ def is_env_passthrough(var_name: str) -> bool:
Returns ``True`` if the variable was registered by a skill or listed in
the user's ``tools.env_passthrough`` config.
"""
if var_name in _allowed_env_vars:
if var_name in _get_allowed():
return True
return var_name in _load_config_passthrough()
def get_all_passthrough() -> frozenset[str]:
"""Return the union of skill-registered and config-based passthrough vars."""
return frozenset(_allowed_env_vars) | _load_config_passthrough()
return frozenset(_get_allowed()) | _load_config_passthrough()
def clear_env_passthrough() -> None:
"""Reset the skill-scoped allowlist (e.g. on session reset)."""
_allowed_env_vars.clear()
_get_allowed().clear()
def reset_config_cache() -> None: