This commit addresses the security vulnerability where the WebSocket gateway was exposed on 0.0.0.0 without authentication.

## Changes

### Security Improvements

1. **Localhost binding by default**: Changed HOST from "0.0.0.0" to "127.0.0.1"
   - Gateway now only listens on localhost by default
   - External binding is possible via the NEXUS_WS_HOST environment variable
2. **Token-based authentication**: Added the NEXUS_WS_TOKEN environment variable
   - If set, clients must send an auth message with a valid token
   - If not set, no authentication is required (backward compatible)
   - Auth timeout: 5 seconds
3. **Rate limiting**:
   - Connection rate limiting: 10 connections per IP per 60 seconds
   - Message rate limiting: 100 messages per connection per 60 seconds
   - Configurable via module-level constants
4. **Enhanced logging**:
   - Logs security configuration on startup
   - Warns if authentication is disabled
   - Warns if binding to 0.0.0.0

### Configuration

Environment variables:
- NEXUS_WS_HOST: Host to bind to (default: 127.0.0.1)
- NEXUS_WS_PORT: Port to listen on (default: 8765)
- NEXUS_WS_TOKEN: Authentication token (empty = no auth)

### Backward Compatibility

- Default behavior is now secure (localhost only)
- No authentication by default (same as before)
- Existing clients on the same host work without changes; remote clients now require NEXUS_WS_HOST to be set
- External binding possible via NEXUS_WS_HOST=0.0.0.0

## Security Impact

- Prevents unauthorized access from external networks (by default)
- Limits connection flooding
- Limits message flooding
- Maintains backward compatibility for local clients

Fixes #1504
234 lines
8.6 KiB
Python
234 lines
8.6 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.
|
|
This server acts as the central hub for the-nexus, connecting the mind (nexus_think.py),
|
|
the body (Evennia/Morrowind), and the visualization surface.
|
|
|
|
Security features:
|
|
- Binds to 127.0.0.1 by default (localhost only)
|
|
- Optional external binding via NEXUS_WS_HOST environment variable
|
|
- Token-based authentication via NEXUS_WS_TOKEN environment variable
|
|
- Rate limiting on connections (per IP) and on messages (per connection)
|
|
- Connection logging and monitoring
|
|
"""
|
|
import asyncio
import hmac
import ipaddress
import json
import logging
import os
import signal
import sys
import time
from collections import defaultdict
from typing import Set, Dict, Optional

# Branch protected file - see POLICY.md
import websockets
|
|
|
|
# --- Runtime configuration (overridable via environment variables) ----------
# Interface to listen on; loopback unless the operator explicitly opts in.
HOST = os.environ.get("NEXUS_WS_HOST", "127.0.0.1")
# TCP port for the WebSocket listener.
PORT = int(os.environ.get("NEXUS_WS_PORT", "8765"))
# Shared secret clients must present; the empty string disables auth.
AUTH_TOKEN = os.environ.get("NEXUS_WS_TOKEN", "")

# --- Sliding-window rate-limit tuning ---------------------------------------
RATE_LIMIT_WINDOW = 60  # window length, in seconds
RATE_LIMIT_MAX_CONNECTIONS = 10  # accepted connections allowed per IP per window
RATE_LIMIT_MAX_MESSAGES = 100  # messages allowed per connection per window

# --- Logging ----------------------------------------------------------------
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger("nexus-gateway")
|
|
|
|
# --- Shared mutable state ---------------------------------------------------
# Connections that passed the admission checks in broadcast_handler and
# therefore receive broadcasts.
clients: Set[websockets.WebSocketServerProtocol] = set()
# Per-connection message timestamps, keyed by id() of the websocket object.
message_tracker: Dict[int, list] = defaultdict(list)
# Per-IP connection timestamps consumed by the connection rate limiter.
connection_tracker: Dict[str, list] = defaultdict(list)
|
|
|
|
_last_connection_sweep = 0.0  # monotonic time of the last stale-IP sweep


def check_rate_limit(ip: str) -> bool:
    """Record a connection attempt from ``ip`` and report whether it is allowed.

    An IP may open at most RATE_LIMIT_MAX_CONNECTIONS accepted connections
    within any sliding RATE_LIMIT_WINDOW-second window. Rejected attempts are
    not recorded.

    Args:
        ip: Remote address of the connecting client.

    Returns:
        True if the connection is allowed (and has been recorded), False if
        the IP has exhausted its budget for the current window.
    """
    global _last_connection_sweep
    # monotonic() is immune to wall-clock adjustments (NTP, manual changes),
    # which would otherwise expire or pin recorded timestamps incorrectly.
    now = time.monotonic()

    # At most once per window, forget IPs with no recent connections. Without
    # this every distinct source address would stay in the tracker forever.
    if now - _last_connection_sweep >= RATE_LIMIT_WINDOW:
        stale = [
            addr
            for addr, stamps in connection_tracker.items()
            # Timestamps are appended in order, so stamps[-1] is the newest.
            if not stamps or now - stamps[-1] >= RATE_LIMIT_WINDOW
        ]
        for addr in stale:
            del connection_tracker[addr]
        _last_connection_sweep = now

    recent = [t for t in connection_tracker[ip] if now - t < RATE_LIMIT_WINDOW]
    if len(recent) >= RATE_LIMIT_MAX_CONNECTIONS:
        connection_tracker[ip] = recent
        return False

    recent.append(now)
    connection_tracker[ip] = recent
    return True
|
|
|
|
def check_message_rate_limit(connection_id: int) -> bool:
    """Record one inbound message for a connection and report whether it is allowed.

    A connection may send at most RATE_LIMIT_MAX_MESSAGES messages within any
    sliding RATE_LIMIT_WINDOW-second window. Messages over the limit are not
    recorded.

    Args:
        connection_id: Key identifying the connection in ``message_tracker``
            (broadcast_handler passes ``id(websocket)``).

    Returns:
        True if the message is within the limit, False otherwise.
    """
    # Monotonic clock: measuring an interval must not be affected by
    # wall-clock jumps.
    now = time.monotonic()
    recent = [t for t in message_tracker[connection_id] if now - t < RATE_LIMIT_WINDOW]
    allowed = len(recent) < RATE_LIMIT_MAX_MESSAGES
    if allowed:
        recent.append(now)
    message_tracker[connection_id] = recent
    return allowed
|
|
|
|
async def authenticate_connection(websocket: websockets.WebSocketServerProtocol) -> bool:
    """Run the optional token handshake for a freshly accepted connection.

    When AUTH_TOKEN is empty, authentication is disabled and every connection
    is accepted. Otherwise the client's first message must arrive within
    5 seconds and be a JSON object ``{"type": "auth", "token": <token>}``
    whose token equals AUTH_TOKEN.

    Args:
        websocket: The connection to authenticate. When auth is enabled its
            first message is consumed by this handshake.

    Returns:
        True if the client may proceed, False if it should be disconnected.
        Failures are logged; no exception is propagated.
    """
    if not AUTH_TOKEN:
        # No authentication required
        return True

    addr = websocket.remote_address
    try:
        # Wait for authentication message (first message should be auth)
        auth_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
        auth_data = json.loads(auth_message)
    except asyncio.TimeoutError:
        logger.warning(f"Authentication timeout from {addr}")
        return False
    except json.JSONDecodeError:
        logger.warning(f"Invalid auth JSON from {addr}")
        return False
    except Exception as e:
        logger.error(f"Authentication error from {addr}: {e}")
        return False

    # Valid JSON that is not an object (list, string, number) has no .get();
    # treat it as a malformed handshake rather than an internal error.
    if not isinstance(auth_data, dict) or auth_data.get("type") != "auth":
        logger.warning(f"Invalid auth message type from {addr}")
        return False

    token = auth_data.get("token", "")
    # compare_digest's running time does not depend on where the inputs
    # differ, so response timing cannot leak the token prefix by prefix.
    # Compare bytes: compare_digest rejects non-ASCII str, and surrogatepass
    # keeps lone surrogates (legal as JSON escapes) from raising here.
    if not isinstance(token, str) or not hmac.compare_digest(
        token.encode("utf-8", "surrogatepass"),
        AUTH_TOKEN.encode("utf-8", "surrogatepass"),
    ):
        logger.warning(f"Invalid auth token from {addr}")
        return False

    logger.info(f"Authenticated connection from {addr}")
    return True
|
|
|
|
async def _broadcast(message, sender) -> None:
    """Relay ``message`` to every open client in ``clients`` except ``sender``.

    All sends run concurrently. A client whose send raises is logged and
    removed from ``clients`` so later broadcasts skip it.
    """
    # Snapshot targets synchronously (no await inside the comprehension), so
    # `clients` cannot change size while it is being iterated.
    targets = [client for client in clients if client is not sender and client.open]
    if not targets:
        return

    results = await asyncio.gather(
        *(client.send(message) for client in targets),
        return_exceptions=True,
    )
    for client, result in zip(targets, results):
        if isinstance(result, Exception):
            logger.error(f"Failed to send to client {client.remote_address}: {result}")
            clients.discard(client)


async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
    """Serve one client: admission checks, then relay its messages to all peers.

    The connection is closed with code 1008 if its IP is over the connection
    rate limit or the token handshake fails. Otherwise it is registered in
    ``clients`` and every message it sends is forwarded unchanged to all
    other connected clients. Messages over the per-connection rate limit are
    answered with an error frame and dropped.
    """
    addr = websocket.remote_address
    ip = addr[0] if addr else "unknown"
    connection_id = id(websocket)

    # Admission check 1: per-IP connection rate limit.
    if not check_rate_limit(ip):
        logger.warning(f"Connection rate limit exceeded for {ip}")
        await websocket.close(1008, "Rate limit exceeded")
        return

    # Admission check 2: token handshake (accepts immediately if auth is off).
    if not await authenticate_connection(websocket):
        await websocket.close(1008, "Authentication failed")
        return

    clients.add(websocket)
    logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")

    try:
        async for message in websocket:
            if not check_message_rate_limit(connection_id):
                logger.warning(f"Message rate limit exceeded for {addr}")
                await websocket.send(json.dumps({
                    "type": "error",
                    "message": "Message rate limit exceeded",
                }))
                continue

            # Best-effort parse used only for debug logging; the raw frame is
            # relayed unchanged regardless. ValueError covers JSONDecodeError
            # and UnicodeDecodeError (non-UTF-8 binary frames); RecursionError
            # covers pathologically nested JSON.
            try:
                data = json.loads(message)
            except (ValueError, TypeError, RecursionError):
                data = None
            # Only JSON objects have .get(); valid arrays/strings/numbers would
            # otherwise raise and abort this client's loop.
            if isinstance(data, dict):
                msg_type = data.get("type", "unknown")
                if msg_type in ("agent_register", "thought", "action"):
                    logger.debug(f"Received {msg_type} from {addr}")

            # Broadcast to all OTHER clients
            await _broadcast(message, websocket)

    except websockets.exceptions.ConnectionClosed:
        logger.debug(f"Connection closed by client {addr}")
    except Exception as e:
        logger.error(f"Error handling client {addr}: {e}")
    finally:
        clients.discard(websocket)
        # id() values are reused once an object is freed: dropping the entry
        # stops a later connection inheriting these timestamps and keeps the
        # tracker from growing with every connection ever made.
        message_tracker.pop(connection_id, None)
        logger.info(f"Client disconnected {addr}. Total clients: {len(clients)}")
|
|
|
|
def _host_binding_scope(host: str) -> str:
    """Classify a bind address by how widely it exposes the gateway.

    Returns:
        "all" for addresses that bind every interface ("", "0.0.0.0", "::"),
        "loopback" for "localhost" and loopback IPs (127.0.0.0/8, ::1), and
        "network" for anything else. Hostnames other than "localhost" are not
        resolved here and are treated conservatively as "network".
    """
    if not host:
        # asyncio's create_server (used by websockets.serve) binds all
        # interfaces for an empty host.
        return "all"
    if host == "localhost":
        return "loopback"
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        return "network"
    if ip.is_unspecified:
        return "all"
    if ip.is_loopback:
        return "loopback"
    return "network"


def _log_security_config() -> None:
    """Log the effective authentication, bind-address and rate-limit settings."""
    if AUTH_TOKEN:
        logger.info("Authentication: ENABLED (token required)")
    else:
        logger.warning("Authentication: DISABLED (no token required)")

    scope = _host_binding_scope(HOST)
    if scope == "all":
        logger.warning(f"Host binding: {HOST} (all interfaces) - SECURITY RISK")
    elif scope == "loopback":
        logger.info(f"Host binding: {HOST} (localhost only)")
    else:
        # A specific non-loopback address is exposed on that interface's
        # network, so it must not be reported as localhost-only.
        logger.warning(f"Host binding: {HOST} (may be reachable from other machines)")

    logger.info(f"Rate limiting: {RATE_LIMIT_MAX_CONNECTIONS} connections/IP/{RATE_LIMIT_WINDOW}s, "
                f"{RATE_LIMIT_MAX_MESSAGES} messages/connection/{RATE_LIMIT_WINDOW}s")


async def main():
    """Run the gateway until SIGINT/SIGTERM, then close any remaining clients."""
    _log_security_config()
    logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")

    # Set up signal handlers for graceful shutdown: the signals resolve
    # `stop`, and awaiting it below keeps the server running until then.
    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def shutdown():
        # Guarded so a repeated signal does not call set_result twice.
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown)
        except NotImplementedError:
            # Signal handlers not supported on Windows
            pass

    async with websockets.serve(broadcast_handler, HOST, PORT):
        logger.info("Gateway is ready and listening.")
        await stop

    logger.info("Shutting down Nexus WS gateway...")
    # Close any remaining client connections (handlers may have already cleaned up)
    remaining = {c for c in clients if c.open}
    if remaining:
        logger.info(f"Closing {len(remaining)} active connections...")
        await asyncio.gather(*(client.close() for client in remaining), return_exceptions=True)
    clients.clear()

    logger.info("Shutdown complete.")
|
|
|
|
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl+C when no signal handler was installed (e.g. on Windows).
        pass
    except Exception as e:
        # exc_info=True records the traceback; the message alone would not
        # show where a fatal error originated.
        logger.critical(f"Fatal server error: {e}", exc_info=True)
        sys.exit(1)
|