fix(security): block untrusted browser access to api server (#2451)

Co-authored-by: ifrederico <fr@tecompanytea.com>
This commit is contained in:
Teknium
2026-03-22 04:08:48 -07:00
committed by GitHub
parent b81926def6
commit e109a8b502
6 changed files with 196 additions and 33 deletions

View File

@@ -119,22 +119,33 @@ class TestAdapterInit:
def test_custom_config_from_extra(self):
config = PlatformConfig(
enabled=True,
extra={"host": "0.0.0.0", "port": 9999, "key": "sk-test"},
extra={
"host": "0.0.0.0",
"port": 9999,
"key": "sk-test",
"cors_origins": ["http://localhost:3000"],
},
)
adapter = APIServerAdapter(config)
assert adapter._host == "0.0.0.0"
assert adapter._port == 9999
assert adapter._api_key == "sk-test"
assert adapter._cors_origins == ("http://localhost:3000",)
def test_config_from_env(self, monkeypatch):
monkeypatch.setenv("API_SERVER_HOST", "10.0.0.1")
monkeypatch.setenv("API_SERVER_PORT", "7777")
monkeypatch.setenv("API_SERVER_KEY", "sk-env")
monkeypatch.setenv("API_SERVER_CORS_ORIGINS", "http://localhost:3000, http://127.0.0.1:3000")
config = PlatformConfig(enabled=True)
adapter = APIServerAdapter(config)
assert adapter._host == "10.0.0.1"
assert adapter._port == 7777
assert adapter._api_key == "sk-env"
assert adapter._cors_origins == (
"http://localhost:3000",
"http://127.0.0.1:3000",
)
# ---------------------------------------------------------------------------
@@ -190,11 +201,13 @@ class TestAuth:
# ---------------------------------------------------------------------------
def _make_adapter(api_key: str = "") -> APIServerAdapter:
def _make_adapter(api_key: str = "", cors_origins=None) -> APIServerAdapter:
"""Create an adapter with optional API key."""
extra = {}
if api_key:
extra["key"] = api_key
if cors_origins is not None:
extra["cors_origins"] = cors_origins
config = PlatformConfig(enabled=True, extra=extra)
return APIServerAdapter(config)
@@ -202,6 +215,7 @@ def _make_adapter(api_key: str = "") -> APIServerAdapter:
def _create_app(adapter: APIServerAdapter) -> web.Application:
"""Create the aiohttp app from the adapter (without starting the full server)."""
app = web.Application(middlewares=[cors_middleware])
app["api_server_adapter"] = adapter
app.router.add_get("/health", adapter._handle_health)
app.router.add_get("/v1/models", adapter._handle_models)
app.router.add_post("/v1/chat/completions", adapter._handle_chat_completions)
@@ -788,6 +802,19 @@ class TestConfigIntegration:
assert config.platforms[Platform.API_SERVER].extra.get("port") == 9999
assert config.platforms[Platform.API_SERVER].extra.get("host") == "0.0.0.0"
def test_env_override_cors_origins(self, monkeypatch):
monkeypatch.setenv("API_SERVER_ENABLED", "true")
monkeypatch.setenv(
"API_SERVER_CORS_ORIGINS",
"http://localhost:3000, http://127.0.0.1:3000",
)
from gateway.config import load_gateway_config
config = load_gateway_config()
assert config.platforms[Platform.API_SERVER].extra.get("cors_origins") == [
"http://localhost:3000",
"http://127.0.0.1:3000",
]
def test_api_server_in_connected_platforms(self):
config = GatewayConfig()
config.platforms[Platform.API_SERVER] = PlatformConfig(enabled=True)
@@ -1156,26 +1183,91 @@ class TestTruncation:
class TestCORS:
def test_origin_allowed_for_non_browser_client(self, adapter):
assert adapter._origin_allowed("") is True
def test_origin_rejected_by_default(self, adapter):
assert adapter._origin_allowed("http://evil.example") is False
def test_origin_allowed_for_allowlist_match(self):
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
assert adapter._origin_allowed("http://localhost:3000") is True
def test_cors_headers_for_origin_disabled_by_default(self, adapter):
assert adapter._cors_headers_for_origin("http://localhost:3000") is None
def test_cors_headers_for_origin_matches_allowlist(self):
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
headers = adapter._cors_headers_for_origin("http://localhost:3000")
assert headers is not None
assert headers["Access-Control-Allow-Origin"] == "http://localhost:3000"
assert "POST" in headers["Access-Control-Allow-Methods"]
def test_cors_headers_for_origin_rejects_unknown_origin(self):
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
assert adapter._cors_headers_for_origin("http://evil.example") is None
@pytest.mark.asyncio
async def test_cors_headers_on_get(self, adapter):
"""CORS headers present on normal responses."""
async def test_cors_headers_not_present_by_default(self, adapter):
"""CORS is disabled unless explicitly configured."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.get("/health")
assert resp.status == 200
assert resp.headers.get("Access-Control-Allow-Origin") == "*"
assert resp.headers.get("Access-Control-Allow-Origin") is None
@pytest.mark.asyncio
async def test_browser_origin_rejected_by_default(self, adapter):
"""Browser-originated requests are rejected unless explicitly allowed."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.get("/health", headers={"Origin": "http://evil.example"})
assert resp.status == 403
assert resp.headers.get("Access-Control-Allow-Origin") is None
@pytest.mark.asyncio
async def test_cors_options_preflight_rejected_by_default(self, adapter):
"""Browser preflight is rejected unless CORS is explicitly configured."""
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.options(
"/v1/chat/completions",
headers={
"Origin": "http://evil.example",
"Access-Control-Request-Method": "POST",
},
)
assert resp.status == 403
assert resp.headers.get("Access-Control-Allow-Origin") is None
@pytest.mark.asyncio
async def test_cors_headers_present_for_allowed_origin(self):
"""Allowed origins receive explicit CORS headers."""
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
resp = await cli.get("/health", headers={"Origin": "http://localhost:3000"})
assert resp.status == 200
assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
assert "POST" in resp.headers.get("Access-Control-Allow-Methods", "")
assert "DELETE" in resp.headers.get("Access-Control-Allow-Methods", "")
@pytest.mark.asyncio
async def test_cors_options_preflight(self, adapter):
"""OPTIONS preflight request returns CORS headers."""
async def test_cors_options_preflight_allowed_for_configured_origin(self):
"""Configured origins can complete browser preflight."""
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
app = _create_app(adapter)
async with TestClient(TestServer(app)) as cli:
# OPTIONS to a known path — aiohttp will route through middleware
resp = await cli.options("/health")
resp = await cli.options(
"/v1/chat/completions",
headers={
"Origin": "http://localhost:3000",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "Authorization, Content-Type",
},
)
assert resp.status == 200
assert resp.headers.get("Access-Control-Allow-Origin") == "*"
assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
assert "Authorization" in resp.headers.get("Access-Control-Allow-Headers", "")