[kimi] Add WebSocket authentication for Matrix connections (#682) #744
@@ -155,6 +155,12 @@ class Settings(BaseSettings):
|
||||
# Example: "http://100.124.176.28:8080" or "https://alexanderwhitestone.com"
|
||||
matrix_frontend_url: str = "" # Empty = disabled
|
||||
|
||||
# WebSocket authentication token for Matrix connections.
|
||||
# When set, clients must provide this token via ?token= query param
|
||||
# or in the first message as {"type": "auth", "token": "..."}.
|
||||
# Empty/unset = auth disabled (dev mode).
|
||||
matrix_ws_token: str = ""
|
||||
|
||||
# Trusted hosts for the Host header check (TrustedHostMiddleware).
|
||||
# Set TRUSTED_HOSTS as a comma-separated list. Wildcards supported (e.g. "*.ts.net").
|
||||
# Defaults include localhost + Tailscale MagicDNS. Add your Tailscale IP if needed.
|
||||
|
||||
@@ -415,6 +415,50 @@ async def _heartbeat(websocket: WebSocket) -> None:
|
||||
logger.debug("Heartbeat stopped — connection gone")
|
||||
|
||||
|
||||
async def _authenticate_ws(websocket: WebSocket) -> bool:
|
||||
"""Authenticate WebSocket connection using matrix_ws_token.
|
||||
|
||||
Checks for token in query param ?token= first. If no query param,
|
||||
accepts the connection and waits for first message with
|
||||
{"type": "auth", "token": "..."}.
|
||||
|
||||
Returns True if authenticated (or if auth is disabled).
|
||||
Returns False and closes connection with code 4001 if invalid.
|
||||
"""
|
||||
token_setting = settings.matrix_ws_token
|
||||
|
||||
# Auth disabled in dev mode (empty/unset token)
|
||||
if not token_setting:
|
||||
return True
|
||||
|
||||
# Check query param first (can validate before accept)
|
||||
query_token = websocket.query_params.get("token", "")
|
||||
if query_token:
|
||||
if query_token == token_setting:
|
||||
return True
|
||||
# Invalid token in query param - we need to accept to close properly
|
||||
await websocket.accept()
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return False
|
||||
|
||||
# No query token - accept and wait for auth message
|
||||
await websocket.accept()
|
||||
|
||||
# Wait for auth message as first message
|
||||
try:
|
||||
raw = await websocket.receive_text()
|
||||
data = json.loads(raw)
|
||||
if data.get("type") == "auth" and data.get("token") == token_setting:
|
||||
return True
|
||||
# Invalid auth message
|
||||
await websocket.close(code=4001, reason="Invalid token")
|
||||
return False
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Non-JSON first message without valid token
|
||||
await websocket.close(code=4001, reason="Authentication required")
|
||||
return False
|
||||
|
||||
|
||||
@router.websocket("/ws")
|
||||
async def world_ws(websocket: WebSocket) -> None:
|
||||
"""Accept a Workshop client and keep it alive for state broadcasts.
|
||||
@@ -423,8 +467,28 @@ async def world_ws(websocket: WebSocket) -> None:
|
||||
client never starts from a blank slate. Incoming frames are parsed
|
||||
as JSON — ``visitor_message`` triggers a bark response. A background
|
||||
heartbeat ping runs every 15 s to detect dead connections early.
|
||||
|
||||
Authentication:
|
||||
- If matrix_ws_token is configured, clients must provide it via
|
||||
?token= query param or in the first message as
|
||||
{"type": "auth", "token": "..."}.
|
||||
- Invalid token results in close code 4001.
|
||||
- Valid token receives a connection_ack message.
|
||||
"""
|
||||
await websocket.accept()
|
||||
# Authenticate (may accept connection internally)
|
||||
is_authed = await _authenticate_ws(websocket)
|
||||
if not is_authed:
|
||||
logger.info("World WS connection rejected — invalid token")
|
||||
return
|
||||
|
||||
# Auth passed - accept if not already accepted
|
||||
if websocket.client_state.name != "CONNECTED":
|
||||
await websocket.accept()
|
||||
|
||||
# Send connection_ack if auth was required
|
||||
if settings.matrix_ws_token:
|
||||
await websocket.send_text(json.dumps({"type": "connection_ack"}))
|
||||
|
||||
_ws_clients.append(websocket)
|
||||
logger.info("World WS connected — %d clients", len(_ws_clients))
|
||||
|
||||
|
||||
@@ -246,6 +246,131 @@ def test_world_ws_endpoint_accepts_connection(client):
|
||||
pass # Connection accepted — just close it
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WebSocket Authentication Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestWebSocketAuth:
|
||||
"""Tests for WebSocket token-based authentication."""
|
||||
|
||||
def test_ws_auth_disabled_when_token_unset(self, client):
|
||||
"""When matrix_ws_token is empty, auth is disabled (dev mode)."""
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = ""
|
||||
with client.websocket_connect("/api/world/ws") as ws:
|
||||
# Should receive world_state without auth
|
||||
msg = json.loads(ws.receive_text())
|
||||
assert msg["type"] == "world_state"
|
||||
|
||||
def test_ws_valid_token_via_query_param(self, client):
|
||||
"""Valid token via ?token= query param allows connection."""
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = "secret123"
|
||||
with client.websocket_connect("/api/world/ws?token=secret123") as ws:
|
||||
# Should receive connection_ack first
|
||||
ack = json.loads(ws.receive_text())
|
||||
assert ack["type"] == "connection_ack"
|
||||
# Then world_state
|
||||
msg = json.loads(ws.receive_text())
|
||||
assert msg["type"] == "world_state"
|
||||
|
||||
def test_ws_valid_token_via_first_message(self, client):
|
||||
"""Valid token via first auth message allows connection."""
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = "secret123"
|
||||
with client.websocket_connect("/api/world/ws") as ws:
|
||||
# Send auth message
|
||||
ws.send_text(json.dumps({"type": "auth", "token": "secret123"}))
|
||||
# Should receive connection_ack
|
||||
ack = json.loads(ws.receive_text())
|
||||
assert ack["type"] == "connection_ack"
|
||||
# Then world_state
|
||||
msg = json.loads(ws.receive_text())
|
||||
assert msg["type"] == "world_state"
|
||||
|
||||
def test_ws_invalid_token_via_query_param(self, client):
|
||||
"""Invalid token via ?token= closes connection with code 4001."""
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = "secret123"
|
||||
# When auth fails with query param, accept() is called then close()
|
||||
# The test client raises WebSocketDisconnect on close
|
||||
with pytest.raises(WebSocketDisconnect) as exc_info:
|
||||
with client.websocket_connect("/api/world/ws?token=wrongtoken") as ws:
|
||||
# Try to receive - should trigger the close
|
||||
ws.receive_text()
|
||||
assert exc_info.value.code == 4001
|
||||
|
||||
def test_ws_invalid_token_via_first_message(self, client):
|
||||
"""Invalid token via first message closes connection with code 4001."""
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = "secret123"
|
||||
with client.websocket_connect("/api/world/ws") as ws:
|
||||
# Send invalid auth message
|
||||
ws.send_text(json.dumps({"type": "auth", "token": "wrongtoken"}))
|
||||
# Connection should close with 4001
|
||||
with pytest.raises(WebSocketDisconnect) as exc_info:
|
||||
ws.receive_text()
|
||||
assert exc_info.value.code == 4001
|
||||
|
||||
def test_ws_no_token_when_auth_required(self, client):
|
||||
"""No token when auth is required closes connection with code 4001."""
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = "secret123"
|
||||
with client.websocket_connect("/api/world/ws") as ws:
|
||||
# Send non-auth message without token
|
||||
ws.send_text(json.dumps({"type": "visitor_message", "text": "hello"}))
|
||||
# Connection should close with 4001
|
||||
with pytest.raises(WebSocketDisconnect) as exc_info:
|
||||
ws.receive_text()
|
||||
assert exc_info.value.code == 4001
|
||||
|
||||
def test_ws_non_json_first_message_when_auth_required(self, client):
|
||||
"""Non-JSON first message when auth required closes with 4001."""
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = "secret123"
|
||||
with client.websocket_connect("/api/world/ws") as ws:
|
||||
# Send non-JSON message
|
||||
ws.send_text("not json")
|
||||
# Connection should close with 4001
|
||||
with pytest.raises(WebSocketDisconnect) as exc_info:
|
||||
ws.receive_text()
|
||||
assert exc_info.value.code == 4001
|
||||
|
||||
def test_ws_existing_behavior_unchanged_when_token_not_configured(self, client, tmp_path):
|
||||
"""Existing /api/world/ws behavior unchanged when token not configured."""
|
||||
f = tmp_path / "presence.json"
|
||||
f.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"version": 1,
|
||||
"liveness": "2026-03-19T02:00:00Z",
|
||||
"mood": "exploring",
|
||||
"current_focus": "testing",
|
||||
"active_threads": [],
|
||||
"recent_events": [],
|
||||
"concerns": [],
|
||||
}
|
||||
)
|
||||
)
|
||||
with patch("dashboard.routes.world.settings") as mock_settings:
|
||||
mock_settings.matrix_ws_token = "" # Not configured
|
||||
with patch("dashboard.routes.world.PRESENCE_FILE", f):
|
||||
with client.websocket_connect("/api/world/ws") as ws:
|
||||
# Should receive world_state directly (no connection_ack)
|
||||
msg = json.loads(ws.receive_text())
|
||||
assert msg["type"] == "world_state"
|
||||
assert msg["timmyState"]["mood"] == "exploring"
|
||||
|
||||
|
||||
def test_world_ws_sends_snapshot_on_connect(client, tmp_path):
|
||||
"""WebSocket sends a world_state snapshot immediately on connect."""
|
||||
f = tmp_path / "presence.json"
|
||||
|
||||
Reference in New Issue
Block a user