355 lines
13 KiB
Python
355 lines
13 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.
|
|
This server acts as the central hub for the-nexus, connecting the mind (nexus_think.py),
|
|
the body (Evennia/Morrowind), and the visualization surface.
|
|
|
|
Security features:
|
|
- Binds to 127.0.0.1 by default (localhost only)
|
|
- Optional external binding via NEXUS_WS_HOST environment variable
|
|
- Token-based authentication via NEXUS_WS_TOKEN environment variable
|
|
- Rate limiting on connections
|
|
- Connection logging and monitoring
|
|
"""
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import pty
|
|
import signal
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from typing import Set, Dict, Optional
|
|
from collections import defaultdict
|
|
|
|
# Branch protected file - see POLICY.md
|
|
import websockets
|
|
|
|
# Configuration
# Every NEXUS_* value is read from the environment once, at import time.

# TCP port of the main broadcast gateway.
PORT = int(os.getenv("NEXUS_WS_PORT", "8765"))
# TCP port of the operator shell PTY gateway.
PTY_PORT = int(os.getenv("NEXUS_PTY_PORT", "8766"))
# Bind address of the broadcast gateway; defaults to localhost only.
HOST = os.getenv("NEXUS_WS_HOST", "127.0.0.1")
# Shared secret clients must present; an empty value means no auth required.
AUTH_TOKEN = os.getenv("NEXUS_WS_TOKEN", "")

# Length, in seconds, of the sliding window used by both rate limiters.
RATE_LIMIT_WINDOW = 60
# Max connections per IP per window.
RATE_LIMIT_MAX_CONNECTIONS = 10
# Max messages per connection per window.
RATE_LIMIT_MAX_MESSAGES = 100
|
|
|
|
# Logging setup
# Root logger: INFO and above, one "<timestamp> [LEVEL] message" line per record.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
# Named logger used by every function in this module.
logger = logging.getLogger("nexus-gateway")
|
|
|
|
# State
# Module-level and mutable; read and written by the handlers and rate-limit
# helpers defined later in this file.

# Sockets currently registered for broadcast; broadcast_handler adds a socket
# after it passes authentication and discards it when the connection ends.
clients: Set[websockets.WebSocketServerProtocol] = set()
# IP -> [timestamps] of recent connection attempts (see check_rate_limit).
connection_tracker: Dict[str, list] = defaultdict(list)
# connection_id -> [timestamps] of recent messages, where connection_id is
# id(websocket) (see check_message_rate_limit and broadcast_handler).
message_tracker: Dict[int, list] = defaultdict(list)
|
|
|
|
def check_rate_limit(ip: str) -> bool:
    """Record a connection attempt from ``ip`` and say whether it is allowed.

    Timestamps older than ``RATE_LIMIT_WINDOW`` seconds are dropped from
    ``connection_tracker[ip]`` first. If the number of remaining timestamps
    has already reached ``RATE_LIMIT_MAX_CONNECTIONS`` the attempt is refused
    and not recorded; otherwise the current time is appended.

    Args:
        ip: Remote address used as the key into ``connection_tracker``.

    Returns:
        ``True`` if the connection is within the limit, ``False`` if not.
    """
    current = time.time()
    recent = [
        stamp
        for stamp in connection_tracker[ip]
        if current - stamp < RATE_LIMIT_WINDOW
    ]
    connection_tracker[ip] = recent

    allowed = len(recent) < RATE_LIMIT_MAX_CONNECTIONS
    if allowed:
        # Only accepted attempts count towards the window.
        recent.append(current)
    return allowed
|
|
|
|
def check_message_rate_limit(connection_id: int) -> bool:
    """Record a message for ``connection_id`` and say whether it is allowed.

    Works on ``message_tracker[connection_id]``: entries older than
    ``RATE_LIMIT_WINDOW`` seconds are pruned, then the message is accepted
    (and its timestamp stored) only while fewer than
    ``RATE_LIMIT_MAX_MESSAGES`` entries remain in the window.

    Args:
        connection_id: Key identifying the connection in ``message_tracker``.

    Returns:
        ``True`` if the message is within the limit, ``False`` if not.
    """
    timestamp = time.time()
    window = [
        seen
        for seen in message_tracker[connection_id]
        if timestamp - seen < RATE_LIMIT_WINDOW
    ]
    message_tracker[connection_id] = window

    if len(window) >= RATE_LIMIT_MAX_MESSAGES:
        # Over the limit: refuse without recording this message.
        return False

    window.append(timestamp)
    return True
|
|
|
|
async def authenticate_connection(websocket: websockets.WebSocketServerProtocol) -> bool:
    """Authenticate a WebSocket connection using the shared token.

    When ``AUTH_TOKEN`` is empty, authentication is disabled and the
    connection is accepted without reading anything from the socket.
    Otherwise the first message must arrive within 5 seconds and be a JSON
    object of the form ``{"type": "auth", "token": "<token>"}``.

    Args:
        websocket: The connection to authenticate.

    Returns:
        ``True`` if the connection may proceed, ``False`` otherwise. Failures
        are logged here; closing the socket is left to the caller.
    """
    import hmac  # local import: hmac is not among the file-level imports

    if not AUTH_TOKEN:
        # No authentication required
        return True

    peer = websocket.remote_address
    try:
        # Wait for authentication message (first message should be auth)
        auth_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
        auth_data = json.loads(auth_message)

        # Valid JSON is not necessarily an object; reject lists, strings,
        # numbers etc. explicitly instead of failing on ``.get``.
        if not isinstance(auth_data, dict) or auth_data.get("type") != "auth":
            logger.warning(f"Invalid auth message type from {peer}")
            return False

        token = auth_data.get("token", "")
        # Compare as bytes in constant time so the check does not leak how
        # many leading characters of the token were correct.
        token_ok = isinstance(token, str) and hmac.compare_digest(
            token.encode("utf-8"), AUTH_TOKEN.encode("utf-8")
        )
        if not token_ok:
            logger.warning(f"Invalid auth token from {peer}")
            return False

        logger.info(f"Authenticated connection from {peer}")
        return True

    except asyncio.TimeoutError:
        logger.warning(f"Authentication timeout from {peer}")
        return False
    except json.JSONDecodeError:
        logger.warning(f"Invalid auth JSON from {peer}")
        return False
    except Exception as e:
        # Covers a connection closed mid-handshake, undecodable bytes, etc.
        logger.error(f"Authentication error from {peer}: {e}")
        return False
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Git status helper (issue #1695)
|
|
# ---------------------------------------------------------------------------
|
|
def _get_git_status() -> dict:
|
|
"""Return a dict describing the current repo git state."""
|
|
repo_root = os.path.dirname(os.path.abspath(__file__))
|
|
try:
|
|
branch_out = subprocess.check_output(
|
|
["git", "rev-parse", "--abbrev-ref", "HEAD"],
|
|
cwd=repo_root, stderr=subprocess.DEVNULL, text=True
|
|
).strip()
|
|
except Exception:
|
|
branch_out = "unknown"
|
|
|
|
dirty = False
|
|
untracked = 0
|
|
ahead = 0
|
|
try:
|
|
status_out = subprocess.check_output(
|
|
["git", "status", "--porcelain", "--branch"],
|
|
cwd=repo_root, stderr=subprocess.DEVNULL, text=True
|
|
)
|
|
for line in status_out.splitlines():
|
|
if line.startswith("##") and "ahead" in line:
|
|
import re
|
|
m = re.search(r"ahead (\d+)", line)
|
|
if m:
|
|
ahead = int(m.group(1))
|
|
elif line.startswith("??"):
|
|
untracked += 1
|
|
elif line and not line.startswith("##"):
|
|
dirty = True
|
|
except Exception:
|
|
pass
|
|
|
|
return {
|
|
"type": "git_status",
|
|
"branch": branch_out,
|
|
"dirty": dirty,
|
|
"untracked": untracked,
|
|
"ahead": ahead,
|
|
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# PTY shell handler (issue #1695) — operator cockpit terminal
|
|
# Binds on PTY_PORT (default 8766), localhost only.
|
|
# Each WebSocket connection gets its own PTY subprocess.
|
|
# ---------------------------------------------------------------------------
|
|
async def pty_handler(websocket: websockets.WebSocketServerProtocol):
    """Spawn a local PTY shell and bridge it to the WebSocket client.

    The shell is taken from ``$SHELL`` (default ``/bin/bash``) and started in
    its own session with the PTY slave as stdin/stdout/stderr. Two pump tasks
    copy PTY output to the socket and socket messages to the PTY. As soon as
    either pump ends (client disconnected or shell exited) the other is
    cancelled, the shell is killed and the PTY master is closed.

    NOTE(security): unlike ``broadcast_handler``, this handler calls neither
    ``authenticate_connection`` nor ``check_rate_limit``; any client able to
    reach the PTY port gets a shell. Review before exposing the port.

    Args:
        websocket: The operator's connection.
    """
    addr = websocket.remote_address
    logger.info(f"[PTY] Operator shell connection from {addr}")

    shell = os.environ.get("SHELL", "/bin/bash")
    master_fd, slave_fd = pty.openpty()

    try:
        proc = await asyncio.create_subprocess_exec(
            shell,
            stdin=slave_fd,
            stdout=slave_fd,
            stderr=slave_fd,
            preexec_fn=os.setsid,
            close_fds=True,
        )
    except BaseException:
        # Spawn failed (or was cancelled): do not leak the master side.
        os.close(master_fd)
        raise
    finally:
        # The parent never needs the slave side, whether or not the spawn
        # succeeded.
        os.close(slave_fd)

    loop = asyncio.get_running_loop()

    async def pty_to_ws():
        """Read PTY output and forward to WebSocket."""
        try:
            while True:
                # os.read blocks, so it runs in the default executor.
                data = await loop.run_in_executor(None, os.read, master_fd, 4096)
                if not data:
                    break
                await websocket.send(data.decode("utf-8", errors="replace"))
        except (OSError, websockets.exceptions.ConnectionClosed):
            pass

    async def ws_to_pty():
        """Read WebSocket input and forward to PTY."""
        try:
            async for message in websocket:
                payload = message.encode("utf-8") if isinstance(message, str) else message
                os.write(master_fd, payload)
        except (OSError, websockets.exceptions.ConnectionClosed):
            pass

    reader = asyncio.create_task(pty_to_ws())
    writer = asyncio.create_task(ws_to_pty())
    try:
        # Stop when EITHER direction ends. Waiting for both would keep the
        # handler (and the shell) alive after a disconnect for as long as the
        # shell stays silent, and after a shell exit for as long as the client
        # stays connected.
        await asyncio.wait({reader, writer}, return_when=asyncio.FIRST_COMPLETED)
    finally:
        reader.cancel()
        writer.cancel()
        results = await asyncio.gather(reader, writer, return_exceptions=True)
        for result in results:
            if isinstance(result, Exception) and not isinstance(result, asyncio.CancelledError):
                logger.error(f"[PTY] Bridge error for {addr}: {result}")

        try:
            proc.kill()
        except ProcessLookupError:
            # Shell already exited on its own.
            pass
        await proc.wait()

        # Close the master only after the shell is gone, so the descriptor
        # number cannot be reused while a read may still be pending in the
        # executor thread.
        try:
            os.close(master_fd)
        except OSError:
            pass
        logger.info(f"[PTY] Shell session ended for {addr}")
|
|
|
|
|
|
async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
    """Handle one client connection and relay its messages to the others.

    The connection is first checked against the per-IP connection rate limit
    and then authenticated; on failure it is closed with code 1008. Each
    received message is then:

    * rejected with an ``error`` reply when the per-connection message rate
      limit is exceeded;
    * answered directly when it is a JSON object whose ``type`` is
      ``git_status_request``;
    * otherwise forwarded unchanged to every other open client. Clients whose
      send fails are removed from ``clients``.

    Messages that are not JSON objects are not inspected, only forwarded.

    Args:
        websocket: The client connection being served.
    """
    addr = websocket.remote_address
    ip = addr[0] if addr else "unknown"
    connection_id = id(websocket)

    # Check connection rate limit
    if not check_rate_limit(ip):
        logger.warning(f"Connection rate limit exceeded for {ip}")
        await websocket.close(1008, "Rate limit exceeded")
        return

    # Authenticate if token is required
    if not await authenticate_connection(websocket):
        await websocket.close(1008, "Authentication failed")
        return

    clients.add(websocket)
    logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")
    loop = asyncio.get_running_loop()

    try:
        async for message in websocket:
            # Check message rate limit
            if not check_message_rate_limit(connection_id):
                logger.warning(f"Message rate limit exceeded for {addr}")
                await websocket.send(json.dumps({
                    "type": "error",
                    "message": "Message rate limit exceeded"
                }))
                continue

            # Parse for logging/validation if it's JSON. ValueError covers
            # both malformed JSON and binary frames that are not valid UTF-8.
            try:
                data = json.loads(message)
            except (ValueError, TypeError):
                data = None

            # Only JSON *objects* carry a "type"; lists, strings and numbers
            # are valid JSON too and must not break the connection.
            if isinstance(data, dict):
                msg_type = data.get("type", "unknown")
                # Optional: log specific important message types
                if msg_type in ("agent_register", "thought", "action"):
                    logger.debug(f"Received {msg_type} from {addr}")
                # Handle git status requests from the operator cockpit (issue #1695)
                if msg_type == "git_status_request":
                    # git runs as a blocking subprocess; keep it off the
                    # event loop so other clients are not stalled.
                    git_info = await loop.run_in_executor(None, _get_git_status)
                    await websocket.send(json.dumps(git_info))
                    continue

            # Broadcast to all OTHER clients
            targets = [c for c in clients if c is not websocket and c.open]
            if not targets:
                continue

            results = await asyncio.gather(
                *(target.send(message) for target in targets),
                return_exceptions=True,
            )
            for target, result in zip(targets, results):
                if isinstance(result, Exception):
                    logger.error(f"Failed to send to client {target.remote_address}: {result}")
                    clients.discard(target)

    except websockets.exceptions.ConnectionClosed:
        logger.debug(f"Connection closed by client {addr}")
    except Exception as e:
        logger.error(f"Error handling client {addr}: {e}")
    finally:
        clients.discard(websocket)
        # Drop this connection's rate-limit history: id() values can be
        # reused, so a stale entry would both leak memory and be charged to
        # an unrelated future connection.
        message_tracker.pop(connection_id, None)
        logger.info(f"Client disconnected {addr}. Total clients: {len(clients)}")
|
|
|
|
def _host_binding_scope(host: str) -> str:
    """Classify a bind address as ``"all"``, ``"loopback"`` or ``"other"``."""
    import ipaddress  # local import: not among the file-level imports

    if host in ("", "0.0.0.0", "::"):
        return "all"
    if host == "localhost":
        return "loopback"
    try:
        return "loopback" if ipaddress.ip_address(host).is_loopback else "other"
    except ValueError:
        # Not an IP literal (e.g. a hostname): cannot be assumed to be local.
        return "other"


async def main():
    """Main server loop with graceful shutdown.

    Logs the security-relevant configuration, starts the broadcast gateway on
    ``HOST:PORT`` and the PTY gateway on ``127.0.0.1:PTY_PORT``, then waits
    until SIGINT or SIGTERM is received. On the way out any client still
    registered in ``clients`` is closed and the set is cleared.
    """
    # Log security configuration
    if AUTH_TOKEN:
        logger.info("Authentication: ENABLED (token required)")
    else:
        logger.warning("Authentication: DISABLED (no token required)")

    # Only a loopback address may be reported as "localhost only"; every
    # other binding is reachable from outside this machine.
    scope = _host_binding_scope(HOST)
    if scope == "all":
        logger.warning(f"Host binding: {HOST or '*'} (all interfaces) - SECURITY RISK")
    elif scope == "loopback":
        logger.info(f"Host binding: {HOST} (localhost only)")
    else:
        logger.warning(f"Host binding: {HOST} (non-loopback address) - reachable from the network")

    logger.info(f"Rate limiting: {RATE_LIMIT_MAX_CONNECTIONS} connections/IP/{RATE_LIMIT_WINDOW}s, "
                f"{RATE_LIMIT_MAX_MESSAGES} messages/connection/{RATE_LIMIT_WINDOW}s")

    logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")

    # Set up signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def shutdown():
        # A second signal must not try to resolve the future again.
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown)
        except NotImplementedError:
            # Signal handlers not supported on Windows
            pass

    async with websockets.serve(broadcast_handler, HOST, PORT):
        logger.info("Gateway is ready and listening.")
        # Also start the PTY gateway on PTY_PORT (operator cockpit shell, issue #1695)
        async with websockets.serve(pty_handler, "127.0.0.1", PTY_PORT):
            logger.info(f"PTY shell gateway listening on ws://127.0.0.1:{PTY_PORT}/pty")
            await stop

    logger.info("Shutting down Nexus WS gateway...")
    # Close any remaining client connections (handlers may have already cleaned up)
    remaining = {c for c in clients if c.open}
    if remaining:
        logger.info(f"Closing {len(remaining)} active connections...")
        close_tasks = [client.close() for client in remaining]
        await asyncio.gather(*close_tasks, return_exceptions=True)
    clients.clear()

    logger.info("Shutdown complete.")
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the gateway until it stops or fails.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is a normal way to stop the server; exit quietly.
        pass
    except Exception as e:
        # Include the traceback: the message alone is rarely enough to
        # diagnose a crash at this level.
        logger.critical(f"Fatal server error: {e}", exc_info=True)
        sys.exit(1)
|