#!/usr/bin/env python3
"""
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.

This server acts as the central hub for the-nexus, connecting the mind
(nexus_think.py), the body (Evennia/Morrowind), and the visualization surface.
Every message received from one authenticated client is re-broadcast verbatim
to all other connected clients.

Security features:
- Binds to 127.0.0.1 by default (localhost only)
- Optional external binding via NEXUS_WS_HOST environment variable
- Token-based authentication via NEXUS_WS_TOKEN environment variable
- Rate limiting on connections (per IP) and on messages (per connection)
- Connection logging and monitoring
"""

import asyncio
import hmac
import json
import logging
import os
import signal
import sys
import time
from typing import Set, Dict, Optional
from collections import defaultdict

# Branch protected file - see POLICY.md
import websockets

# Configuration
PORT = int(os.environ.get("NEXUS_WS_PORT", "8765"))
HOST = os.environ.get("NEXUS_WS_HOST", "127.0.0.1")  # Default to localhost only
AUTH_TOKEN = os.environ.get("NEXUS_WS_TOKEN", "")  # Empty = no auth required
AUTH_TIMEOUT = 5.0  # seconds a client has to send its auth message
RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_MAX_CONNECTIONS = 10  # max connections per IP per window
RATE_LIMIT_MAX_MESSAGES = 100  # max messages per connection per window

# Host values used only to describe the binding in the startup log.
_LOOPBACK_HOSTS = frozenset({"127.0.0.1", "localhost", "::1"})
_ALL_INTERFACE_HOSTS = frozenset({"0.0.0.0", "::", ""})

# Message types worth a debug log line when they pass through the gateway.
_LOGGED_MESSAGE_TYPES = ("agent_register", "thought", "action")

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger("nexus-gateway")

# State
clients: Set[websockets.WebSocketServerProtocol] = set()
# IP -> timestamps (time.monotonic) of accepted connections in the current window
connection_tracker: Dict[str, list] = defaultdict(list)
# connection_id (id(websocket)) -> timestamps (time.monotonic) of recent messages;
# entries are removed when the connection's handler exits.
message_tracker: Dict[int, list] = defaultdict(list)


def check_rate_limit(ip: str) -> bool:
    """Record a connection attempt from *ip* and report whether it is allowed.

    Uses a sliding window of accepted-connection timestamps per IP. IPs whose
    timestamps have all aged out of the window are forgotten so the tracker
    does not grow without bound as distinct clients come and go.

    Args:
        ip: Remote address of the connecting client.

    Returns:
        True if the connection is within the limit (and has been recorded);
        False if *ip* already has RATE_LIMIT_MAX_CONNECTIONS accepted
        connections inside the current window.
    """
    now = time.monotonic()

    # Drop IPs with no timestamps left inside the window (memory hygiene only;
    # an aged-out history has no effect on the limit).
    stale_ips = [
        tracked_ip
        for tracked_ip, stamps in connection_tracker.items()
        if not stamps or now - max(stamps) >= RATE_LIMIT_WINDOW
    ]
    for tracked_ip in stale_ips:
        del connection_tracker[tracked_ip]

    window = [t for t in connection_tracker[ip] if now - t < RATE_LIMIT_WINDOW]
    connection_tracker[ip] = window
    if len(window) >= RATE_LIMIT_MAX_CONNECTIONS:
        return False
    window.append(now)
    return True


def check_message_rate_limit(connection_id: int) -> bool:
    """Record a message on *connection_id* and report whether it is allowed.

    Args:
        connection_id: ``id()`` of the sending websocket.

    Returns:
        True if the message is within the limit (and has been recorded);
        False if the connection already sent RATE_LIMIT_MAX_MESSAGES messages
        inside the current window.
    """
    now = time.monotonic()
    window = [t for t in message_tracker[connection_id] if now - t < RATE_LIMIT_WINDOW]
    message_tracker[connection_id] = window
    if len(window) >= RATE_LIMIT_MAX_MESSAGES:
        return False
    window.append(now)
    return True


def _token_matches(candidate: object) -> bool:
    """Return True if *candidate* equals AUTH_TOKEN, compared in constant time.

    Non-string tokens are rejected outright. ``surrogatepass`` keeps encoding
    total for strings holding lone surrogates, which both JSON ("\\ud800") and
    os.environ (surrogateescape) can produce.
    """
    if not isinstance(candidate, str):
        return False
    return hmac.compare_digest(
        candidate.encode("utf-8", "surrogatepass"),
        AUTH_TOKEN.encode("utf-8", "surrogatepass"),
    )


async def authenticate_connection(websocket: websockets.WebSocketServerProtocol) -> bool:
    """Authenticate a WebSocket connection using the shared token.

    When AUTH_TOKEN is empty, every connection is accepted. Otherwise the
    client must send, within AUTH_TIMEOUT seconds, a JSON object of the form
    ``{"type": "auth", "token": "<token>"}`` as its first message.

    Returns:
        True if the connection is authenticated; False otherwise. Failures are
        logged, never raised.
    """
    if not AUTH_TOKEN:
        # No authentication required
        return True

    addr = websocket.remote_address
    try:
        # The first message on the connection must be the auth message.
        auth_message = await asyncio.wait_for(websocket.recv(), timeout=AUTH_TIMEOUT)
        auth_data = json.loads(auth_message)
    except asyncio.TimeoutError:
        logger.warning(f"Authentication timeout from {addr}")
        return False
    except ValueError:
        # JSONDecodeError, or UnicodeDecodeError for undecodable binary frames.
        logger.warning(f"Invalid auth JSON from {addr}")
        return False
    except Exception as e:
        logger.error(f"Authentication error from {addr}: {e}")
        return False

    # Valid JSON that is not an object (list, number, ...) is also rejected here.
    if not isinstance(auth_data, dict) or auth_data.get("type") != "auth":
        logger.warning(f"Invalid auth message type from {addr}")
        return False

    if not _token_matches(auth_data.get("token", "")):
        logger.warning(f"Invalid auth token from {addr}")
        return False

    logger.info(f"Authenticated connection from {addr}")
    return True


def _log_message_type(message, addr) -> None:
    """Emit a debug line for notable JSON message types; ignore everything else.

    Parsing is skipped entirely unless DEBUG logging is enabled, because the
    result is only used for that log line. Non-JSON payloads and JSON values
    that are not objects are ignored; they are still broadcast by the caller.
    """
    if not logger.isEnabledFor(logging.DEBUG):
        return
    try:
        data = json.loads(message)
    except (ValueError, TypeError):
        return
    if not isinstance(data, dict):
        return
    msg_type = data.get("type", "unknown")
    if msg_type in _LOGGED_MESSAGE_TYPES:
        logger.debug(f"Received {msg_type} from {addr}")


async def _broadcast(message, sender) -> None:
    """Send *message* to every open client except *sender*.

    Sends run concurrently. Any client whose send raises is logged and removed
    from the client set.
    """
    if not clients:
        return
    # Snapshot the targets before awaiting: other handlers may mutate `clients`.
    targets = [client for client in clients if client != sender and client.open]
    if not targets:
        return

    results = await asyncio.gather(
        *(asyncio.create_task(client.send(message)) for client in targets),
        return_exceptions=True,
    )

    disconnected = set()
    for client, result in zip(targets, results):
        if isinstance(result, Exception):
            logger.error(f"Failed to send to client {client.remote_address}: {result}")
            disconnected.add(client)
    if disconnected:
        clients.difference_update(disconnected)


async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
    """Handle one client connection: rate-limit, authenticate, then relay.

    Each message the client sends (within its message rate limit) is
    broadcast verbatim to all other connected clients. Over-limit messages are
    dropped and answered with an error frame.
    """
    addr = websocket.remote_address
    ip = addr[0] if addr else "unknown"
    connection_id = id(websocket)

    # Check connection rate limit
    if not check_rate_limit(ip):
        logger.warning(f"Connection rate limit exceeded for {ip}")
        await websocket.close(1008, "Rate limit exceeded")
        return

    # Authenticate if token is required
    if not await authenticate_connection(websocket):
        await websocket.close(1008, "Authentication failed")
        return

    clients.add(websocket)
    logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")

    try:
        async for message in websocket:
            if not check_message_rate_limit(connection_id):
                logger.warning(f"Message rate limit exceeded for {addr}")
                await websocket.send(json.dumps({
                    "type": "error",
                    "message": "Message rate limit exceeded"
                }))
                continue

            _log_message_type(message, addr)
            # Broadcast to all OTHER clients
            await _broadcast(message, websocket)
    except websockets.exceptions.ConnectionClosed:
        logger.debug(f"Connection closed by client {addr}")
    except Exception as e:
        logger.error(f"Error handling client {addr}: {e}")
    finally:
        clients.discard(websocket)
        # id() values are reused once an object is freed. Dropping the entry
        # frees memory and stops a later connection from inheriting this
        # connection's message history.
        message_tracker.pop(connection_id, None)
        logger.info(f"Client disconnected {addr}. Total clients: {len(clients)}")


async def main():
    """Main server loop with graceful shutdown on SIGINT/SIGTERM."""
    # Log security configuration
    if AUTH_TOKEN:
        logger.info("Authentication: ENABLED (token required)")
    else:
        logger.warning("Authentication: DISABLED (no token required)")

    if HOST in _ALL_INTERFACE_HOSTS:
        logger.warning(f"Host binding: {HOST} (all interfaces) - SECURITY RISK")
    elif HOST in _LOOPBACK_HOSTS:
        logger.info(f"Host binding: {HOST} (localhost only)")
    else:
        logger.warning(f"Host binding: {HOST} (externally reachable interface)")

    logger.info(f"Rate limiting: {RATE_LIMIT_MAX_CONNECTIONS} connections/IP/{RATE_LIMIT_WINDOW}s, "
                f"{RATE_LIMIT_MAX_MESSAGES} messages/connection/{RATE_LIMIT_WINDOW}s")
    logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")

    # Set up signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def shutdown():
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown)
        except NotImplementedError:
            # Signal handlers not supported on Windows
            pass

    async with websockets.serve(broadcast_handler, HOST, PORT):
        logger.info("Gateway is ready and listening.")
        await stop

    logger.info("Shutting down Nexus WS gateway...")

    # Close any remaining client connections (handlers may have already cleaned up)
    remaining = {c for c in clients if c.open}
    if remaining:
        logger.info(f"Closing {len(remaining)} active connections...")
        close_tasks = [client.close() for client in remaining]
        await asyncio.gather(*close_tasks, return_exceptions=True)

    clients.clear()
    logger.info("Shutdown complete.")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass
    except Exception as e:
        logger.critical(f"Fatal server error: {e}")
        sys.exit(1)