fix(security): block untrusted browser access to api server (#2451)
Co-authored-by: ifrederico <fr@tecompanytea.com>
This commit is contained in:
@@ -119,22 +119,33 @@ class TestAdapterInit:
|
||||
def test_custom_config_from_extra(self):
|
||||
config = PlatformConfig(
|
||||
enabled=True,
|
||||
extra={"host": "0.0.0.0", "port": 9999, "key": "sk-test"},
|
||||
extra={
|
||||
"host": "0.0.0.0",
|
||||
"port": 9999,
|
||||
"key": "sk-test",
|
||||
"cors_origins": ["http://localhost:3000"],
|
||||
},
|
||||
)
|
||||
adapter = APIServerAdapter(config)
|
||||
assert adapter._host == "0.0.0.0"
|
||||
assert adapter._port == 9999
|
||||
assert adapter._api_key == "sk-test"
|
||||
assert adapter._cors_origins == ("http://localhost:3000",)
|
||||
|
||||
def test_config_from_env(self, monkeypatch):
|
||||
monkeypatch.setenv("API_SERVER_HOST", "10.0.0.1")
|
||||
monkeypatch.setenv("API_SERVER_PORT", "7777")
|
||||
monkeypatch.setenv("API_SERVER_KEY", "sk-env")
|
||||
monkeypatch.setenv("API_SERVER_CORS_ORIGINS", "http://localhost:3000, http://127.0.0.1:3000")
|
||||
config = PlatformConfig(enabled=True)
|
||||
adapter = APIServerAdapter(config)
|
||||
assert adapter._host == "10.0.0.1"
|
||||
assert adapter._port == 7777
|
||||
assert adapter._api_key == "sk-env"
|
||||
assert adapter._cors_origins == (
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -190,11 +201,13 @@ class TestAuth:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_adapter(api_key: str = "") -> APIServerAdapter:
|
||||
def _make_adapter(api_key: str = "", cors_origins=None) -> APIServerAdapter:
|
||||
"""Create an adapter with optional API key."""
|
||||
extra = {}
|
||||
if api_key:
|
||||
extra["key"] = api_key
|
||||
if cors_origins is not None:
|
||||
extra["cors_origins"] = cors_origins
|
||||
config = PlatformConfig(enabled=True, extra=extra)
|
||||
return APIServerAdapter(config)
|
||||
|
||||
@@ -202,6 +215,7 @@ def _make_adapter(api_key: str = "") -> APIServerAdapter:
|
||||
def _create_app(adapter: APIServerAdapter) -> web.Application:
|
||||
"""Create the aiohttp app from the adapter (without starting the full server)."""
|
||||
app = web.Application(middlewares=[cors_middleware])
|
||||
app["api_server_adapter"] = adapter
|
||||
app.router.add_get("/health", adapter._handle_health)
|
||||
app.router.add_get("/v1/models", adapter._handle_models)
|
||||
app.router.add_post("/v1/chat/completions", adapter._handle_chat_completions)
|
||||
@@ -788,6 +802,19 @@ class TestConfigIntegration:
|
||||
assert config.platforms[Platform.API_SERVER].extra.get("port") == 9999
|
||||
assert config.platforms[Platform.API_SERVER].extra.get("host") == "0.0.0.0"
|
||||
|
||||
def test_env_override_cors_origins(self, monkeypatch):
|
||||
monkeypatch.setenv("API_SERVER_ENABLED", "true")
|
||||
monkeypatch.setenv(
|
||||
"API_SERVER_CORS_ORIGINS",
|
||||
"http://localhost:3000, http://127.0.0.1:3000",
|
||||
)
|
||||
from gateway.config import load_gateway_config
|
||||
config = load_gateway_config()
|
||||
assert config.platforms[Platform.API_SERVER].extra.get("cors_origins") == [
|
||||
"http://localhost:3000",
|
||||
"http://127.0.0.1:3000",
|
||||
]
|
||||
|
||||
def test_api_server_in_connected_platforms(self):
|
||||
config = GatewayConfig()
|
||||
config.platforms[Platform.API_SERVER] = PlatformConfig(enabled=True)
|
||||
@@ -1156,26 +1183,91 @@ class TestTruncation:
|
||||
|
||||
|
||||
class TestCORS:
|
||||
def test_origin_allowed_for_non_browser_client(self, adapter):
|
||||
assert adapter._origin_allowed("") is True
|
||||
|
||||
def test_origin_rejected_by_default(self, adapter):
|
||||
assert adapter._origin_allowed("http://evil.example") is False
|
||||
|
||||
def test_origin_allowed_for_allowlist_match(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
assert adapter._origin_allowed("http://localhost:3000") is True
|
||||
|
||||
def test_cors_headers_for_origin_disabled_by_default(self, adapter):
|
||||
assert adapter._cors_headers_for_origin("http://localhost:3000") is None
|
||||
|
||||
def test_cors_headers_for_origin_matches_allowlist(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
headers = adapter._cors_headers_for_origin("http://localhost:3000")
|
||||
assert headers is not None
|
||||
assert headers["Access-Control-Allow-Origin"] == "http://localhost:3000"
|
||||
assert "POST" in headers["Access-Control-Allow-Methods"]
|
||||
|
||||
def test_cors_headers_for_origin_rejects_unknown_origin(self):
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
assert adapter._cors_headers_for_origin("http://evil.example") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_headers_on_get(self, adapter):
|
||||
"""CORS headers present on normal responses."""
|
||||
async def test_cors_headers_not_present_by_default(self, adapter):
|
||||
"""CORS is disabled unless explicitly configured."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health")
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "*"
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_browser_origin_rejected_by_default(self, adapter):
|
||||
"""Browser-originated requests are rejected unless explicitly allowed."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health", headers={"Origin": "http://evil.example"})
|
||||
assert resp.status == 403
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_options_preflight_rejected_by_default(self, adapter):
|
||||
"""Browser preflight is rejected unless CORS is explicitly configured."""
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.options(
|
||||
"/v1/chat/completions",
|
||||
headers={
|
||||
"Origin": "http://evil.example",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
},
|
||||
)
|
||||
assert resp.status == 403
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_headers_present_for_allowed_origin(self):
|
||||
"""Allowed origins receive explicit CORS headers."""
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
resp = await cli.get("/health", headers={"Origin": "http://localhost:3000"})
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
||||
assert "POST" in resp.headers.get("Access-Control-Allow-Methods", "")
|
||||
assert "DELETE" in resp.headers.get("Access-Control-Allow-Methods", "")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_options_preflight(self, adapter):
|
||||
"""OPTIONS preflight request returns CORS headers."""
|
||||
async def test_cors_options_preflight_allowed_for_configured_origin(self):
|
||||
"""Configured origins can complete browser preflight."""
|
||||
adapter = _make_adapter(cors_origins=["http://localhost:3000"])
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
# OPTIONS to a known path — aiohttp will route through middleware
|
||||
resp = await cli.options("/health")
|
||||
resp = await cli.options(
|
||||
"/v1/chat/completions",
|
||||
headers={
|
||||
"Origin": "http://localhost:3000",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": "Authorization, Content-Type",
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "*"
|
||||
assert resp.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
||||
assert "Authorization" in resp.headers.get("Access-Control-Allow-Headers", "")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user