189 lines
6.1 KiB
Python
189 lines
6.1 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.

This server acts as the central hub for the-nexus, connecting the mind (nexus_think.py),
the body (Evennia/Morrowind), and the visualization surface.
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import signal
|
|
import sys
|
|
from dataclasses import dataclass
|
|
from typing import Set
|
|
from urllib.parse import parse_qs, urlparse
|
|
|
|
from websockets.asyncio.server import serve
|
|
from websockets.datastructures import Headers
|
|
from websockets.exceptions import ConnectionClosed
|
|
from websockets.http11 import Request, Response
|
|
|
|
|
|
# Host values that _is_local_host() accepts as local by exact match.
LOCAL_ONLY_HOSTS = {"127.0.0.1", "localhost", "::1"}
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class GatewayConfig:
|
|
host: str
|
|
port: int
|
|
auth_token: str | None = None
|
|
require_auth: bool = False
|
|
|
|
|
|
# Logging: configure the root logger with timestamped, level-tagged lines and
# create the module's named logger.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("nexus-gateway")


# Shared state: every websocket whose handler is currently running.
clients: Set[object] = set()
|
|
|
|
|
|
def _is_local_host(host: str) -> bool:
    """Return True when *host* designates a loopback-only bind target.

    IP literals are parsed rather than prefix-matched, so only genuine
    loopback addresses (all of ``127.0.0.0/8`` and ``::1``) qualify. A plain
    ``startswith("127.")`` test would also accept DNS names such as
    ``127.evil.example``, which can resolve to a public interface and would
    wrongly switch off the auth requirement.

    Anything that is not a loopback IP literal is looked up in
    ``LOCAL_ONLY_HOSTS``; names are matched case-insensitively because host
    names are case-insensitive. Strings that ``ipaddress`` cannot parse and
    that are not listed (including shorthand such as ``127.1``) are treated
    as non-local, i.e. the check fails closed.

    Args:
        host: Host name or IP literal taken from configuration.

    Returns:
        True if the host is considered local, False otherwise.
    """
    # Function-scope import so the fix does not depend on the file's import block.
    import ipaddress

    candidate = host.strip()
    try:
        if ipaddress.ip_address(candidate).is_loopback:
            return True
    except ValueError:
        # Not an IP literal: fall through to the explicit name allow-list.
        pass
    return candidate in LOCAL_ONLY_HOSTS or candidate.lower() in LOCAL_ONLY_HOSTS
|
|
|
|
|
|
def get_gateway_config(env: dict | None = None) -> GatewayConfig:
    """Build the gateway configuration from environment-style settings.

    Recognised keys:
        NEXUS_WS_HOST: Bind host; defaults to ``127.0.0.1`` when unset or blank.
        NEXUS_WS_PORT: Bind port; defaults to ``8765`` when unset or blank.
        NEXUS_WS_AUTH_TOKEN: Token clients must present; blank means "none".

    Authentication is required exactly when the host is not local according
    to ``_is_local_host``.

    Args:
        env: Mapping to read settings from. ``None`` means ``os.environ``;
            an empty mapping is honoured as "no settings", not replaced.

    Returns:
        The resolved ``GatewayConfig``.

    Raises:
        ValueError: If the port is not an integer in ``0..65535``, or if the
            host is non-local and no auth token is configured.
    """
    # `env or os.environ` would silently swap an explicitly passed empty
    # mapping for the real process environment; only default on None.
    settings = os.environ if env is None else env

    host = str(settings.get("NEXUS_WS_HOST", "127.0.0.1")).strip() or "127.0.0.1"

    # Treat a blank port like a blank host: fall back to the default.
    raw_port = str(settings.get("NEXUS_WS_PORT", "8765")).strip() or "8765"
    try:
        port = int(raw_port)
    except ValueError:
        raise ValueError(
            f"NEXUS_WS_PORT must be an integer, got {raw_port!r}"
        ) from None
    if not 0 <= port <= 65535:
        raise ValueError(f"NEXUS_WS_PORT must be between 0 and 65535, got {port}")

    auth_token = str(settings.get("NEXUS_WS_AUTH_TOKEN", "")).strip() or None

    require_auth = not _is_local_host(host)
    if require_auth and not auth_token:
        raise ValueError("NEXUS_WS_AUTH_TOKEN is required when NEXUS_WS_HOST is non-local")

    return GatewayConfig(
        host=host,
        port=port,
        auth_token=auth_token,
        require_auth=require_auth,
    )
|
|
|
|
|
|
def _extract_request_token(request: Request) -> str | None:
    """Pull the client's auth token out of a handshake request.

    Lookup order:
      1. An ``Authorization`` header whose value starts with ``bearer ``
         (case-insensitive). If such a header is present it is decisive:
         an empty token yields ``None`` without consulting the query string.
      2. The ``ws_token`` query parameter, then ``token``; the first value of
         each is used and whitespace-only values are skipped.

    Args:
        request: Handshake request exposing ``headers`` and ``path``.

    Returns:
        The stripped token, or ``None`` when no usable token was supplied.
    """
    header_value = request.headers.get("Authorization")
    if header_value and header_value.lower().startswith("bearer "):
        _scheme, credential = header_value.split(" ", 1)
        return credential.strip() or None

    params = parse_qs(urlparse(request.path).query)
    # First value of each recognised parameter, in priority order.
    first_values = (
        params[name][0].strip()
        for name in ("ws_token", "token")
        if params.get(name)
    )
    return next((value for value in first_values if value), None)
|
|
|
|
|
|
def _unauthorized_response(message: str) -> Response:
    """Build the plain-text ``401 Unauthorized`` reply for a rejected handshake.

    Besides the content type, the response carries:
      * ``WWW-Authenticate: Bearer`` — HTTP requires a 401 to include at
        least one challenge, and bearer tokens are what this gateway reads
        from the ``Authorization`` header.
      * ``Content-Length`` and ``Connection: close`` — so the client can
        frame the body without waiting for the socket to close.

    Args:
        message: Human-readable reason, sent as the UTF-8 response body.

    Returns:
        The ``Response`` to hand back from the ``process_request`` hook.
    """
    body = message.encode("utf-8")
    headers = Headers(
        {
            "Content-Type": "text/plain; charset=utf-8",
            "Content-Length": str(len(body)),
            "Connection": "close",
            "WWW-Authenticate": "Bearer",
        }
    )
    return Response(401, "Unauthorized", headers, body)
|
|
|
|
|
|
def make_process_request(config: GatewayConfig):
    """Create the handshake hook that enforces the gateway's auth policy.

    The returned callable has the ``process_request(connection, request)``
    shape: it returns ``None`` to let the handshake continue, or a 401
    response to reject it.

    Policy:
      * ``config.require_auth`` false: every request is accepted.
      * ``config.require_auth`` true: the request must carry a token equal to
        ``config.auth_token``. If no token is configured, every request is
        rejected; a plain ``token != config.auth_token`` test would instead
        let token-less requests through, because ``None != None`` is false.

    Args:
        config: Gateway settings; only ``require_auth`` and ``auth_token``
            are read.

    Returns:
        The ``process_request`` callable described above.
    """
    # Function-scope import so the fix does not depend on the file's import block.
    import hmac

    # Encode once; compare_digest needs bytes to handle non-ASCII tokens.
    expected = config.auth_token.encode("utf-8") if config.auth_token else None

    def process_request(_connection, request: Request):
        if not config.require_auth:
            return None

        supplied = _extract_request_token(request)
        authorized = (
            expected is not None
            and supplied is not None
            # Constant-time comparison: do not leak how much of the token matched.
            and hmac.compare_digest(supplied.encode("utf-8", "surrogatepass"), expected)
        )
        if not authorized:
            return _unauthorized_response("Missing or invalid websocket auth token")
        return None

    return process_request
|
|
|
|
|
|
def _peek_message_type(message):
    """Return the ``type`` field of a JSON-object *message*, else ``None``.

    Messages that are not JSON, not decodable, too deeply nested, or whose
    top-level value is not an object are opaque to the gateway and yield
    ``None``; an object without a ``type`` key yields ``"unknown"``.
    """
    try:
        payload = json.loads(message)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError (binary frames).
        return None
    if not isinstance(payload, dict):
        # Valid JSON such as `[1, 2]` or `42` has no `.get`; it must not crash the handler.
        return None
    return payload.get("type", "unknown")


async def _relay_to_peers(sender, message) -> None:
    """Send *message* to every registered client except *sender*.

    Sends run concurrently. A peer whose send raises is logged and removed
    from ``clients`` so later broadcasts skip it.
    """
    # Snapshot first: the shared set may change while the sends are awaited.
    peers = [peer for peer in clients if peer is not sender]
    if not peers:
        return

    outcomes = await asyncio.gather(
        *(peer.send(message) for peer in peers),
        return_exceptions=True,
    )
    for peer, outcome in zip(peers, outcomes):
        if isinstance(outcome, Exception):
            peer_addr = getattr(peer, "remote_address", None)
            logger.error("Failed to send to client %s: %s", peer_addr, outcome)
            clients.discard(peer)


async def broadcast_handler(websocket):
    """Serve one client: register it and relay each of its messages to all others.

    Messages are forwarded unchanged; JSON is parsed only to log the type of
    ``agent_register``, ``thought`` and ``action`` messages at debug level.
    A message that cannot be inspected is still relayed and never ends the
    sender's connection.

    Args:
        websocket: The accepted connection; must support async iteration,
            ``send`` and (optionally) ``remote_address``.
    """
    clients.add(websocket)
    addr = getattr(websocket, "remote_address", None)
    logger.info("Client connected from %s. Total clients: %s", addr, len(clients))

    try:
        async for message in websocket:
            msg_type = _peek_message_type(message)
            if msg_type in ("agent_register", "thought", "action"):
                logger.debug("Received %s from %s", msg_type, addr)
            await _relay_to_peers(websocket, message)
    except ConnectionClosed:
        logger.debug("Connection closed by client %s", addr)
    except Exception as e:
        # Unexpected failure: keep the traceback so it can be diagnosed.
        logger.exception("Error handling client %s: %s", addr, e)
    finally:
        clients.discard(websocket)
        logger.info("Client disconnected %s. Total clients: %s", addr, len(clients))
|
|
|
|
|
|
async def main():
    """Run the gateway until SIGINT/SIGTERM, then close remaining clients.

    Loads the configuration, starts the websocket server with the auth hook
    installed, waits for a stop signal, and finally closes any connections
    still registered in ``clients``.
    """
    config = get_gateway_config()
    logger.info(f"Starting Nexus WS gateway on ws://{config.host}:{config.port}")
    if config.require_auth:
        logger.info("Remote gateway mode enabled — websocket auth token required")

    # Set by the signal handlers; setting it more than once is harmless.
    stop_requested = asyncio.Event()
    loop = asyncio.get_running_loop()
    for signum in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(signum, stop_requested.set)
        except NotImplementedError:
            # This event loop cannot install signal handlers; carry on without.
            pass

    server = serve(
        broadcast_handler,
        config.host,
        config.port,
        process_request=make_process_request(config),
    )
    async with server:
        logger.info("Gateway is ready and listening.")
        await stop_requested.wait()

    logger.info("Shutting down Nexus WS gateway...")
    # Copy: handlers remove themselves from `clients` as their sockets close.
    leftover = set(clients)
    if leftover:
        logger.info(f"Closing {len(leftover)} active connections...")
        await asyncio.gather(
            *(client.close() for client in leftover),
            return_exceptions=True,
        )
    clients.clear()
    logger.info("Shutdown complete.")
|
|
|
|
|
|
# Script entry point: run the gateway and map failures to a process exit code.
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Interrupted by the user: exit quietly.
        pass
    except Exception as fatal:
        # Anything else (including bad configuration) is fatal: log and exit 1.
        logger.critical(f"Fatal server error: {fatal}")
        sys.exit(1)
|