wip: default websocket gateway to localhost and gate remote access

This commit is contained in:
Alexander Whitestone
2026-04-15 04:23:17 -04:00
parent 0709859787
commit adb1bae69d

131
server.py
View File

@@ -7,16 +7,29 @@ the body (Evennia/Morrowind), and the visualization surface.
import asyncio
import json
import logging
import os
import signal
import sys
from dataclasses import dataclass
from typing import Set
from urllib.parse import parse_qs, urlparse
# Branch protected file - see POLICY.md
import websockets
from websockets.asyncio.server import serve
from websockets.datastructures import Headers
from websockets.exceptions import ConnectionClosed
from websockets.http11 import Request, Response
LOCAL_ONLY_HOSTS = {"127.0.0.1", "localhost", "::1"}
@dataclass(frozen=True)
class GatewayConfig:
host: str
port: int
auth_token: str | None = None
require_auth: bool = False
# Configuration
PORT = 8765
HOST = "0.0.0.0" # Allow external connections if needed
# Logging setup
logging.basicConfig(
@@ -26,38 +39,82 @@ logging.basicConfig(
)
logger = logging.getLogger("nexus-gateway")
# State
clients: Set[websockets.WebSocketServerProtocol] = set()
async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
# State
clients: Set[object] = set()
def _is_local_host(host: str) -> bool:
return host in LOCAL_ONLY_HOSTS or host.startswith("127.")
def get_gateway_config(env: dict | None = None) -> GatewayConfig:
env = env or os.environ
host = str(env.get("NEXUS_WS_HOST", "127.0.0.1")).strip() or "127.0.0.1"
port = int(str(env.get("NEXUS_WS_PORT", "8765")))
auth_token = str(env.get("NEXUS_WS_AUTH_TOKEN", "")).strip() or None
require_auth = not _is_local_host(host)
if require_auth and not auth_token:
raise ValueError("NEXUS_WS_AUTH_TOKEN is required when NEXUS_WS_HOST is non-local")
return GatewayConfig(host=host, port=port, auth_token=auth_token, require_auth=require_auth)
def _extract_request_token(request: Request) -> str | None:
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.lower().startswith("bearer "):
return auth_header.split(" ", 1)[1].strip() or None
query = parse_qs(urlparse(request.path).query)
for key in ("ws_token", "token"):
values = query.get(key)
if values:
token = values[0].strip()
if token:
return token
return None
def _unauthorized_response(message: str) -> Response:
headers = Headers({"Content-Type": "text/plain; charset=utf-8"})
return Response(401, "Unauthorized", headers, message.encode("utf-8"))
def make_process_request(config: GatewayConfig):
def process_request(_connection, request: Request):
if not config.require_auth:
return None
token = _extract_request_token(request)
if token != config.auth_token:
return _unauthorized_response("Missing or invalid websocket auth token")
return None
return process_request
async def broadcast_handler(websocket):
"""Handles individual client connections and message broadcasting."""
clients.add(websocket)
addr = websocket.remote_address
addr = getattr(websocket, "remote_address", None)
logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")
try:
async for message in websocket:
# Parse for logging/validation if it's JSON
try:
data = json.loads(message)
msg_type = data.get("type", "unknown")
# Optional: log specific important message types
if msg_type in ["agent_register", "thought", "action"]:
logger.debug(f"Received {msg_type} from {addr}")
except (json.JSONDecodeError, TypeError):
pass
# Broadcast to all OTHER clients
if not clients:
continue
disconnected = set()
# Create broadcast tasks, tracking which client each task targets
task_client_pairs = []
for client in clients:
if client != websocket and client.open:
task = asyncio.create_task(client.send(message))
task_client_pairs.append((task, client))
for client in list(clients):
if client is websocket:
continue
task = asyncio.create_task(client.send(message))
task_client_pairs.append((task, client))
if task_client_pairs:
tasks = [pair[0] for pair in task_client_pairs]
@@ -65,13 +122,14 @@ async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
for i, result in enumerate(results):
if isinstance(result, Exception):
target_client = task_client_pairs[i][1]
logger.error(f"Failed to send to client {target_client.remote_address}: {result}")
target_addr = getattr(target_client, "remote_address", None)
logger.error(f"Failed to send to client {target_addr}: {result}")
disconnected.add(target_client)
if disconnected:
clients.difference_update(disconnected)
except websockets.exceptions.ConnectionClosed:
except ConnectionClosed:
logger.debug(f"Connection closed by client {addr}")
except Exception as e:
logger.error(f"Error handling client {addr}: {e}")
@@ -79,14 +137,18 @@ async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
clients.discard(websocket)
logger.info(f"Client disconnected {addr}. Total clients: {len(clients)}")
async def main():
"""Main server loop with graceful shutdown."""
logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")
config = get_gateway_config()
logger.info(f"Starting Nexus WS gateway on ws://{config.host}:{config.port}")
if config.require_auth:
logger.info("Remote gateway mode enabled — websocket auth token required")
# Set up signal handlers for graceful shutdown
loop = asyncio.get_running_loop()
stop = loop.create_future()
def shutdown():
if not stop.done():
stop.set_result(None)
@@ -95,24 +157,27 @@ async def main():
try:
loop.add_signal_handler(sig, shutdown)
except NotImplementedError:
# Signal handlers not supported on Windows
pass
async with websockets.serve(broadcast_handler, HOST, PORT):
async with serve(
broadcast_handler,
config.host,
config.port,
process_request=make_process_request(config),
):
logger.info("Gateway is ready and listening.")
await stop
logger.info("Shutting down Nexus WS gateway...")
# Close any remaining client connections (handlers may have already cleaned up)
remaining = {c for c in clients if c.open}
remaining = set(clients)
if remaining:
logger.info(f"Closing {len(remaining)} active connections...")
close_tasks = [client.close() for client in remaining]
await asyncio.gather(*close_tasks, return_exceptions=True)
clients.clear()
logger.info("Shutdown complete.")
if __name__ == "__main__":
try:
asyncio.run(main())