#!/usr/bin/env python3
"""
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.

This server acts as the central hub for the-nexus, connecting the mind
(nexus_think.py), the body (Evennia/Morrowind), and the visualization surface.

Security features:
- Binds to 127.0.0.1 by default (localhost only)
- Optional external binding via NEXUS_WS_HOST environment variable
- Token-based authentication via NEXUS_WS_TOKEN environment variable,
  enforced on both the broadcast gateway and the operator PTY shell gateway
- Rate limiting on connections (per IP) and messages (per connection)
- Connection logging and monitoring
"""

import asyncio
import codecs
import hmac
import json
import logging
import os
import pty
import re
import signal
import subprocess
import sys
import time
from typing import Set, Dict, Optional
from collections import defaultdict

# Branch protected file - see POLICY.md
import websockets

# Configuration
PORT = int(os.environ.get("NEXUS_WS_PORT", "8765"))
PTY_PORT = int(os.environ.get("NEXUS_PTY_PORT", "8766"))  # operator shell PTY gateway
HOST = os.environ.get("NEXUS_WS_HOST", "127.0.0.1")  # Default to localhost only
AUTH_TOKEN = os.environ.get("NEXUS_WS_TOKEN", "")  # Empty = no auth required
RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_MAX_CONNECTIONS = 10  # max connections per IP per window
RATE_LIMIT_MAX_MESSAGES = 100  # max messages per connection per window

# Bind addresses that only accept connections from this machine.
LOOPBACK_HOSTS = frozenset({"127.0.0.1", "localhost", "::1"})
# Bind addresses that listen on every interface (asyncio treats "" as "all").
WILDCARD_HOSTS = frozenset({"0.0.0.0", "::", ""})
# Upper bound for each git subprocess so a wedged git cannot stall a status request.
GIT_COMMAND_TIMEOUT = 5.0  # seconds
# Extracts N from the "## branch...upstream [ahead N, behind M]" porcelain line.
_AHEAD_RE = re.compile(r"ahead (\d+)")

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger("nexus-gateway")

# State
clients: Set[websockets.WebSocketServerProtocol] = set()
connection_tracker: Dict[str, list] = defaultdict(list)  # IP -> [timestamps]
message_tracker: Dict[int, list] = defaultdict(list)  # connection_id -> [timestamps]


def check_rate_limit(ip: str) -> bool:
    """Record a connection attempt from *ip* and report whether it is allowed.

    Sliding window of RATE_LIMIT_WINDOW seconds: timestamps older than the
    window are discarded, and once RATE_LIMIT_MAX_CONNECTIONS recent attempts
    remain the new attempt is rejected (and not recorded).

    Returns:
        True if the connection may proceed, False if the IP is over its limit.
    """
    now = time.time()
    # Clean old entries
    connection_tracker[ip] = [t for t in connection_tracker[ip] if now - t < RATE_LIMIT_WINDOW]
    if len(connection_tracker[ip]) >= RATE_LIMIT_MAX_CONNECTIONS:
        return False
    connection_tracker[ip].append(now)
    return True


def check_message_rate_limit(connection_id: int) -> bool:
    """Record a message on *connection_id* and report whether it is allowed.

    Same sliding-window scheme as check_rate_limit, capped at
    RATE_LIMIT_MAX_MESSAGES per RATE_LIMIT_WINDOW seconds.

    Returns:
        True if the message may be processed, False if the connection is over its limit.
    """
    now = time.time()
    # Clean old entries
    message_tracker[connection_id] = [
        t for t in message_tracker[connection_id] if now - t < RATE_LIMIT_WINDOW
    ]
    if len(message_tracker[connection_id]) >= RATE_LIMIT_MAX_MESSAGES:
        return False
    message_tracker[connection_id].append(now)
    return True


async def authenticate_connection(websocket: websockets.WebSocketServerProtocol) -> bool:
    """Authenticate a WebSocket connection using the shared token.

    When AUTH_TOKEN is empty every connection is accepted. Otherwise the first
    message (awaited for at most 5 seconds) must be a JSON object of the form
    ``{"type": "auth", "token": "<NEXUS_WS_TOKEN>"}``.

    Returns:
        True if the connection is authenticated, False otherwise. Failures are
        logged rather than raised.
    """
    if not AUTH_TOKEN:
        # No authentication required
        return True

    try:
        # Wait for authentication message (first message should be auth)
        auth_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
        auth_data = json.loads(auth_message)

        # Valid JSON need not be an object (e.g. a list); reject it explicitly.
        if not isinstance(auth_data, dict) or auth_data.get("type") != "auth":
            logger.warning(f"Invalid auth message type from {websocket.remote_address}")
            return False

        token = auth_data.get("token", "")
        # compare_digest takes time independent of where the inputs differ, so
        # response timing does not reveal how much of a guessed token matched.
        # It only accepts str/bytes, hence the type check and byte encoding.
        if not isinstance(token, str) or not hmac.compare_digest(
            token.encode("utf-8", "surrogatepass"),
            AUTH_TOKEN.encode("utf-8", "surrogatepass"),
        ):
            logger.warning(f"Invalid auth token from {websocket.remote_address}")
            return False

        logger.info(f"Authenticated connection from {websocket.remote_address}")
        return True

    except asyncio.TimeoutError:
        logger.warning(f"Authentication timeout from {websocket.remote_address}")
        return False
    except json.JSONDecodeError:
        logger.warning(f"Invalid auth JSON from {websocket.remote_address}")
        return False
    except Exception as e:
        logger.error(f"Authentication error from {websocket.remote_address}: {e}")
        return False


# ---------------------------------------------------------------------------
# Git status helper (issue #1695)
# ---------------------------------------------------------------------------

def _get_git_status() -> dict:
    """Return a dict describing the current repo git state.

    Runs ``git`` in the directory containing this file. Best-effort: if git is
    missing, fails or exceeds GIT_COMMAND_TIMEOUT, the affected fields keep
    their defaults (branch "unknown", not dirty, zero untracked/ahead).

    Blocking (spawns subprocesses); async callers should run it in an executor.

    Returns:
        ``{"type": "git_status", "branch": str, "dirty": bool,
        "untracked": int, "ahead": int}``
    """
    repo_root = os.path.dirname(os.path.abspath(__file__))
    # Failures we expect from spawning git: not installed / bad cwd (OSError),
    # non-zero exit or timeout (SubprocessError), undecodable output.
    git_errors = (OSError, subprocess.SubprocessError, UnicodeDecodeError)

    try:
        branch_out = subprocess.check_output(
            ["git", "rev-parse", "--abbrev-ref", "HEAD"],
            cwd=repo_root,
            stderr=subprocess.DEVNULL,
            text=True,
            timeout=GIT_COMMAND_TIMEOUT,
        ).strip()
    except git_errors as e:
        logger.debug(f"git rev-parse failed: {e}")
        branch_out = "unknown"

    dirty = False
    untracked = 0
    ahead = 0
    try:
        status_out = subprocess.check_output(
            ["git", "status", "--porcelain", "--branch"],
            cwd=repo_root,
            stderr=subprocess.DEVNULL,
            text=True,
            timeout=GIT_COMMAND_TIMEOUT,
        )
        for line in status_out.splitlines():
            if line.startswith("##"):
                # Branch header line; only the "ahead N" count is of interest.
                m = _AHEAD_RE.search(line)
                if m:
                    ahead = int(m.group(1))
            elif line.startswith("??"):
                untracked += 1
            elif line:
                # Any other porcelain entry is a tracked file with changes.
                dirty = True
    except git_errors as e:
        logger.debug(f"git status failed: {e}")

    return {
        "type": "git_status",
        "branch": branch_out,
        "dirty": dirty,
        "untracked": untracked,
        "ahead": ahead,
    }


# ---------------------------------------------------------------------------
# PTY shell handler (issue #1695) — operator cockpit terminal
# Binds on PTY_PORT (default 8766), localhost only.
# Each WebSocket connection gets its own PTY subprocess.
# ---------------------------------------------------------------------------

async def pty_handler(websocket: websockets.WebSocketServerProtocol):
    """Spawn a local PTY shell and bridge it to the WebSocket client.

    Each connection gets its own shell ($SHELL, falling back to /bin/bash)
    attached to a fresh pseudo-terminal. Text frames are written to the PTY as
    UTF-8, binary frames are written verbatim, and PTY output is sent back as
    text frames. When NEXUS_WS_TOKEN is set the client must first send the same
    ``{"type": "auth", "token": ...}`` message as the broadcast gateway.

    The session ends as soon as either direction finishes (WebSocket closed or
    the PTY output stream ended); the shell is then killed and reaped.
    """
    addr = websocket.remote_address
    logger.info(f"[PTY] Operator shell connection from {addr}")

    # This endpoint hands out an interactive shell, so it honours the same token
    # as the broadcast gateway (a no-op when NEXUS_WS_TOKEN is unset).
    # NOTE(review): no Origin header check is done. With no token configured,
    # anything able to reach 127.0.0.1:PTY_PORT gets a shell — possibly
    # including browser pages opening cross-origin WebSockets. Consider one.
    if not await authenticate_connection(websocket):
        logger.warning(f"[PTY] Rejected shell connection from {addr}")
        await websocket.close(1008, "Authentication failed")
        return

    shell = os.environ.get("SHELL", "/bin/bash")
    master_fd, slave_fd = pty.openpty()
    try:
        proc = await asyncio.create_subprocess_exec(
            shell,
            stdin=slave_fd,
            stdout=slave_fd,
            stderr=slave_fd,
            preexec_fn=os.setsid,
            close_fds=True,
        )
    except BaseException:
        # Spawning failed: nothing will ever read the master side.
        os.close(master_fd)
        raise
    finally:
        # The child has its own copies of the slave end; close ours so the
        # master reports EOF/error once the shell side goes away.
        os.close(slave_fd)

    loop = asyncio.get_running_loop()

    async def pty_to_ws():
        """Read PTY output and forward it to the WebSocket as text."""
        # Incremental decoding keeps multi-byte UTF-8 characters intact when
        # they straddle two 4096-byte reads.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        try:
            while True:
                try:
                    data = await loop.run_in_executor(None, os.read, master_fd, 4096)
                except OSError:
                    # e.g. EIO on Linux once the shell side of the PTY has closed.
                    data = b""
                if not data:
                    tail = decoder.decode(b"", final=True)
                    if tail:
                        await websocket.send(tail)
                    break
                text = decoder.decode(data)
                if text:
                    await websocket.send(text)
        except (OSError, websockets.exceptions.ConnectionClosed):
            pass

    async def ws_to_pty():
        """Read WebSocket input and forward it to the PTY."""
        try:
            async for message in websocket:
                if isinstance(message, str):
                    message = message.encode("utf-8")
                os.write(master_fd, message)
        except (OSError, websockets.exceptions.ConnectionClosed):
            pass

    reader = asyncio.ensure_future(pty_to_ws())
    writer = asyncio.ensure_future(ws_to_pty())
    try:
        # Stop when EITHER direction ends. Waiting for both would hang after a
        # client disconnect: the reader sits in a blocking os.read() that only
        # returns when the (now unattended) shell produces output.
        done, _ = await asyncio.wait({reader, writer}, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            if not task.cancelled() and task.exception() is not None:
                logger.error(f"[PTY] Bridge error for {addr}: {task.exception()}")
    finally:
        reader.cancel()
        writer.cancel()
        try:
            os.close(master_fd)
        except OSError:
            pass
        try:
            proc.kill()
        except ProcessLookupError:
            pass
        await proc.wait()
        logger.info(f"[PTY] Shell session ended for {addr}")


def _message_type(message) -> object:
    """Return the "type" field of a JSON-object message.

    Returns None when *message* is not valid JSON (including undecodable bytes)
    or is valid JSON that is not an object; returns "unknown" for an object
    without a "type" key.
    """
    try:
        data = json.loads(message)
    except (ValueError, TypeError):  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        return None
    if not isinstance(data, dict):
        return None
    return data.get("type", "unknown")


async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
    """Handle one client connection and relay its messages to all other clients.

    Applies the per-IP connection limit and token authentication, then relays
    each message verbatim to every other open client. A JSON message with
    ``"type": "git_status_request"`` is answered directly with the repo git
    status instead of being broadcast.
    """
    addr = websocket.remote_address
    ip = addr[0] if addr else "unknown"
    connection_id = id(websocket)

    # Check connection rate limit
    if not check_rate_limit(ip):
        logger.warning(f"Connection rate limit exceeded for {ip}")
        await websocket.close(1008, "Rate limit exceeded")
        return

    # Authenticate if token is required
    if not await authenticate_connection(websocket):
        await websocket.close(1008, "Authentication failed")
        return

    clients.add(websocket)
    logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")

    try:
        async for message in websocket:
            # Check message rate limit
            if not check_message_rate_limit(connection_id):
                logger.warning(f"Message rate limit exceeded for {addr}")
                await websocket.send(json.dumps({
                    "type": "error",
                    "message": "Message rate limit exceeded"
                }))
                continue

            # Parse for logging/validation if it's JSON; anything else is
            # simply relayed as-is.
            msg_type = _message_type(message)

            # Optional: log specific important message types
            if msg_type in ["agent_register", "thought", "action"]:
                logger.debug(f"Received {msg_type} from {addr}")

            # Handle git status requests from the operator cockpit (issue #1695)
            if msg_type == "git_status_request":
                # git runs as blocking subprocesses; keep them off the event
                # loop so other clients' traffic is not stalled meanwhile.
                loop = asyncio.get_running_loop()
                git_info = await loop.run_in_executor(None, _get_git_status)
                await websocket.send(json.dumps(git_info))
                continue

            # Broadcast to all OTHER open clients. Snapshot the targets so the
            # result list can be matched back to the client it was sent to.
            targets = [c for c in clients if c is not websocket and c.open]
            if not targets:
                continue

            results = await asyncio.gather(
                *(client.send(message) for client in targets),
                return_exceptions=True,
            )
            for target_client, result in zip(targets, results):
                if isinstance(result, Exception):
                    logger.error(f"Failed to send to client {target_client.remote_address}: {result}")
                    clients.discard(target_client)

    except websockets.exceptions.ConnectionClosed:
        logger.debug(f"Connection closed by client {addr}")
    except Exception as e:
        logger.error(f"Error handling client {addr}: {e}")
    finally:
        clients.discard(websocket)
        # Drop this connection's message history: otherwise the tracker grows
        # with every connection, and a later socket reusing the same id()
        # would inherit these timestamps.
        message_tracker.pop(connection_id, None)
        logger.info(f"Client disconnected {addr}. Total clients: {len(clients)}")


async def main():
    """Main server loop with graceful shutdown.

    Serves the broadcast gateway on HOST:PORT and the PTY shell gateway on
    127.0.0.1:PTY_PORT until SIGINT/SIGTERM, then closes remaining clients.
    """
    # Log security configuration
    if AUTH_TOKEN:
        logger.info("Authentication: ENABLED (token required)")
    else:
        logger.warning("Authentication: DISABLED (no token required)")

    if HOST in WILDCARD_HOSTS:
        logger.warning(f"Host binding: {HOST} (all interfaces) - SECURITY RISK")
    elif HOST in LOOPBACK_HOSTS:
        logger.info(f"Host binding: {HOST} (localhost only)")
    else:
        # A specific non-loopback address is reachable from that network.
        logger.warning(f"Host binding: {HOST} (non-loopback interface) - reachable from the network")

    logger.info(f"Rate limiting: {RATE_LIMIT_MAX_CONNECTIONS} connections/IP/{RATE_LIMIT_WINDOW}s, "
                f"{RATE_LIMIT_MAX_MESSAGES} messages/connection/{RATE_LIMIT_WINDOW}s")

    logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")

    # Set up signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def shutdown():
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown)
        except NotImplementedError:
            # Signal handlers not supported on Windows
            pass

    async with websockets.serve(broadcast_handler, HOST, PORT):
        logger.info("Gateway is ready and listening.")
        # Also start the PTY gateway on PTY_PORT (operator cockpit shell, issue #1695)
        async with websockets.serve(pty_handler, "127.0.0.1", PTY_PORT):
            logger.info(f"PTY shell gateway listening on ws://127.0.0.1:{PTY_PORT}/pty")
            await stop

    logger.info("Shutting down Nexus WS gateway...")
    # Close any remaining client connections (handlers may have already cleaned up)
    remaining = {c for c in clients if c.open}
    if remaining:
        logger.info(f"Closing {len(remaining)} active connections...")
        close_tasks = [client.close() for client in remaining]
        await asyncio.gather(*close_tasks, return_exceptions=True)
    clients.clear()
    logger.info("Shutdown complete.")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        pass
    except Exception as e:
        logger.critical(f"Fatal server error: {e}")
        sys.exit(1)