Files
the-nexus/server.py
Claude (Opus 4.6) c97364ac13
Some checks failed
Deploy Nexus / deploy (push) Failing after 5s
Staging Verification Gate / verify-staging (push) Failing after 5s
[claude] ATLAS Cockpit: operator inspector rail and session shell (#1695) (#1696)
2026-04-22 05:19:13 +00:00

355 lines
13 KiB
Python

#!/usr/bin/env python3
"""
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.
This server acts as the central hub for the-nexus, connecting the mind (nexus_think.py),
the body (Evennia/Morrowind), and the visualization surface.
Security features:
- Binds to 127.0.0.1 by default (localhost only)
- Optional external binding via NEXUS_WS_HOST environment variable
- Token-based authentication via NEXUS_WS_TOKEN environment variable
- Rate limiting on connections
- Connection logging and monitoring
"""
import asyncio
import codecs
import hmac
import json
import logging
import os
import pty
import re
import signal
import subprocess
import sys
import time
from collections import defaultdict
from typing import Set, Dict, Optional

# Branch protected file - see POLICY.md
import websockets
# Configuration -- every setting below can be overridden through the environment.
PORT = int(os.getenv("NEXUS_WS_PORT", "8765"))
PTY_PORT = int(os.getenv("NEXUS_PTY_PORT", "8766"))  # operator shell PTY gateway
HOST = os.getenv("NEXUS_WS_HOST", "127.0.0.1")  # Default to localhost only
AUTH_TOKEN = os.getenv("NEXUS_WS_TOKEN", "")  # Empty = no auth required
RATE_LIMIT_WINDOW = 60  # seconds
RATE_LIMIT_MAX_CONNECTIONS = 10  # max connections per IP per window
RATE_LIMIT_MAX_MESSAGES = 100  # max messages per connection per window

# Logging setup: one process-wide format, shared by both gateways.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("nexus-gateway")

# State shared by all connection handlers.
# Sockets currently registered with the broadcast gateway.
clients: Set[websockets.WebSocketServerProtocol] = set()
# IP -> [timestamps] of recent connection attempts.
connection_tracker: Dict[str, list] = defaultdict(list)
# connection_id -> [timestamps] of recent messages.
message_tracker: Dict[int, list] = defaultdict(list)
def check_rate_limit(ip: str) -> bool:
    """Record a connection attempt from *ip* unless it is over the limit.

    Timestamps older than ``RATE_LIMIT_WINDOW`` seconds are dropped first.
    Returns ``True`` (and records the attempt) when fewer than
    ``RATE_LIMIT_MAX_CONNECTIONS`` attempts remain in the window, otherwise
    ``False`` without recording anything.
    """
    current = time.time()
    # Keep only the attempts that still fall inside the sliding window.
    recent = [
        stamp for stamp in connection_tracker[ip]
        if current - stamp < RATE_LIMIT_WINDOW
    ]
    connection_tracker[ip] = recent
    if len(recent) < RATE_LIMIT_MAX_CONNECTIONS:
        recent.append(current)
        return True
    return False
def check_message_rate_limit(connection_id: int) -> bool:
    """Record a message for *connection_id* unless it is over the limit.

    Timestamps older than ``RATE_LIMIT_WINDOW`` seconds are dropped first.
    Returns ``True`` (and records the message) when fewer than
    ``RATE_LIMIT_MAX_MESSAGES`` messages remain in the window, otherwise
    ``False`` without recording anything.
    """
    current = time.time()
    # Keep only the messages that still fall inside the sliding window.
    recent = [
        stamp for stamp in message_tracker[connection_id]
        if current - stamp < RATE_LIMIT_WINDOW
    ]
    message_tracker[connection_id] = recent
    if len(recent) < RATE_LIMIT_MAX_MESSAGES:
        recent.append(current)
        return True
    return False
async def authenticate_connection(websocket: websockets.WebSocketServerProtocol) -> bool:
    """Authenticate a WebSocket connection using the shared token.

    When ``AUTH_TOKEN`` is empty, authentication is disabled and every
    connection is accepted. Otherwise the first message received on the
    socket must arrive within 5 seconds and be a JSON object of the form
    ``{"type": "auth", "token": "<token>"}``.

    Args:
        websocket: The freshly accepted connection to authenticate.

    Returns:
        ``True`` if the connection may proceed, ``False`` otherwise. This
        function never raises; every failure is logged and reported as
        ``False`` so the caller can close the socket.
    """
    if not AUTH_TOKEN:
        # No authentication required
        return True

    peer = websocket.remote_address
    try:
        # Wait for authentication message (first message should be auth)
        auth_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
        auth_data = json.loads(auth_message)
    except asyncio.TimeoutError:
        logger.warning(f"Authentication timeout from {peer}")
        return False
    except (json.JSONDecodeError, UnicodeDecodeError):
        # UnicodeDecodeError: a binary frame that is not valid UTF-8/16/32.
        logger.warning(f"Invalid auth JSON from {peer}")
        return False
    except Exception as e:
        # e.g. the peer closed the connection before sending anything.
        logger.error(f"Authentication error from {peer}: {e}")
        return False

    # Valid JSON is not necessarily an object: "[1]" or "42" parse fine but
    # have no .get(), so reject them explicitly instead of via an exception.
    if not isinstance(auth_data, dict) or auth_data.get("type") != "auth":
        logger.warning(f"Invalid auth message type from {peer}")
        return False

    token = auth_data.get("token", "")
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Compare as bytes: compare_digest() only accepts str
    # arguments when they are pure ASCII.
    if not isinstance(token, str) or not hmac.compare_digest(
        token.encode("utf-8"), AUTH_TOKEN.encode("utf-8")
    ):
        logger.warning(f"Invalid auth token from {peer}")
        return False

    logger.info(f"Authenticated connection from {peer}")
    return True
# ---------------------------------------------------------------------------
# Git status helper (issue #1695)
# ---------------------------------------------------------------------------
def _get_git_status() -> dict:
"""Return a dict describing the current repo git state."""
repo_root = os.path.dirname(os.path.abspath(__file__))
try:
branch_out = subprocess.check_output(
["git", "rev-parse", "--abbrev-ref", "HEAD"],
cwd=repo_root, stderr=subprocess.DEVNULL, text=True
).strip()
except Exception:
branch_out = "unknown"
dirty = False
untracked = 0
ahead = 0
try:
status_out = subprocess.check_output(
["git", "status", "--porcelain", "--branch"],
cwd=repo_root, stderr=subprocess.DEVNULL, text=True
)
for line in status_out.splitlines():
if line.startswith("##") and "ahead" in line:
import re
m = re.search(r"ahead (\d+)", line)
if m:
ahead = int(m.group(1))
elif line.startswith("??"):
untracked += 1
elif line and not line.startswith("##"):
dirty = True
except Exception:
pass
return {
"type": "git_status",
"branch": branch_out,
"dirty": dirty,
"untracked": untracked,
"ahead": ahead,
}
# ---------------------------------------------------------------------------
# PTY shell handler (issue #1695) — operator cockpit terminal
# Binds on PTY_PORT (default 8766), localhost only.
# Each WebSocket connection gets its own PTY subprocess.
# When NEXUS_WS_TOKEN is set, the same auth handshake as the broadcast
# gateway is required before a shell is spawned.
# ---------------------------------------------------------------------------
async def pty_handler(websocket: websockets.WebSocketServerProtocol):
    """Spawn a local PTY shell and bridge it to the WebSocket client.

    PTY output is forwarded to the client as text frames; text or binary
    frames from the client are written to the PTY. The session ends as soon
    as EITHER side finishes (shell exits or client disconnects), after which
    the shell is killed and all file descriptors are released.

    Args:
        websocket: The accepted operator connection.
    """
    addr = websocket.remote_address
    logger.info(f"[PTY] Operator shell connection from {addr}")

    # SECURITY: this endpoint hands out an interactive shell. Binding to
    # 127.0.0.1 does not stop other local processes (or a web page opening a
    # WebSocket to localhost) from connecting, so enforce the shared token
    # here as well. authenticate_connection() accepts everything when no
    # token is configured, so token-less setups behave as before.
    if not await authenticate_connection(websocket):
        logger.warning(f"[PTY] Rejected unauthenticated shell connection from {addr}")
        await websocket.close(1008, "Authentication failed")
        return

    shell = os.environ.get("SHELL", "/bin/bash")
    master_fd, slave_fd = pty.openpty()
    try:
        proc = await asyncio.create_subprocess_exec(
            shell,
            stdin=slave_fd,
            stdout=slave_fd,
            stderr=slave_fd,
            preexec_fn=os.setsid,
            close_fds=True,
        )
    except BaseException:
        # Spawn failed: nothing else will ever close the master side.
        os.close(master_fd)
        raise
    finally:
        # The child holds its own copy of the slave end (if it started).
        os.close(slave_fd)

    loop = asyncio.get_running_loop()

    async def wait_readable():
        """Suspend until master_fd is readable (or at EOF/hang-up)."""
        ready = loop.create_future()

        def _mark_ready():
            loop.remove_reader(master_fd)
            if not ready.done():
                ready.set_result(None)

        loop.add_reader(master_fd, _mark_ready)
        try:
            await ready
        finally:
            # Also covers cancellation while still waiting.
            loop.remove_reader(master_fd)

    async def pty_to_ws():
        """Read PTY output and forward to WebSocket."""
        # Incremental decoding keeps a multibyte UTF-8 character intact when
        # it happens to straddle two 4096-byte reads.
        decoder = codecs.getincrementaldecoder("utf-8")(errors="replace")
        try:
            while True:
                await wait_readable()
                data = os.read(master_fd, 4096)
                if not data:
                    break
                text = decoder.decode(data)
                if text:
                    await websocket.send(text)
        except (OSError, websockets.exceptions.ConnectionClosed):
            # OSError: the shell exited and the PTY hung up.
            pass

    async def ws_to_pty():
        """Read WebSocket input and forward to PTY."""
        try:
            async for message in websocket:
                payload = message.encode("utf-8") if isinstance(message, str) else message
                os.write(master_fd, payload)
        except (OSError, websockets.exceptions.ConnectionClosed):
            pass

    reader = asyncio.ensure_future(pty_to_ws())
    writer = asyncio.ensure_future(ws_to_pty())
    try:
        # Stop when the FIRST direction ends. Waiting for both would keep the
        # session (and the shell) alive forever after a client disconnect,
        # because the shell itself never exits on its own.
        await asyncio.wait({reader, writer}, return_when=asyncio.FIRST_COMPLETED)
    finally:
        reader.cancel()
        writer.cancel()
        # Let both tasks unwind (and unregister the fd reader) before the
        # descriptor is closed.
        outcomes = await asyncio.gather(reader, writer, return_exceptions=True)
        for outcome in outcomes:
            if isinstance(outcome, Exception) and not isinstance(outcome, asyncio.CancelledError):
                logger.error(f"[PTY] Bridge task failed for {addr}: {outcome}")
        loop.remove_reader(master_fd)
        try:
            os.close(master_fd)
        except OSError:
            pass
        try:
            proc.kill()
        except ProcessLookupError:
            # Shell already exited.
            pass
        await proc.wait()
        logger.info(f"[PTY] Shell session ended for {addr}")
async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
    """Handle one client connection and relay its messages to all other clients.

    Lifecycle: enforce the per-IP connection rate limit, run the optional
    token handshake, register the socket in ``clients``, then for every
    incoming message enforce the per-connection message rate limit, answer
    ``git_status_request`` messages directly, and broadcast everything else
    unchanged to every other open client. All per-connection state is
    released when the connection ends.

    Args:
        websocket: The accepted client connection.
    """
    addr = websocket.remote_address
    ip = addr[0] if addr else "unknown"
    connection_id = id(websocket)

    # Check connection rate limit
    if not check_rate_limit(ip):
        logger.warning(f"Connection rate limit exceeded for {ip}")
        await websocket.close(1008, "Rate limit exceeded")
        return

    # Authenticate if token is required
    if not await authenticate_connection(websocket):
        await websocket.close(1008, "Authentication failed")
        return

    clients.add(websocket)
    logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")
    try:
        async for message in websocket:
            # Check message rate limit
            if not check_message_rate_limit(connection_id):
                logger.warning(f"Message rate limit exceeded for {addr}")
                await websocket.send(json.dumps({
                    "type": "error",
                    "message": "Message rate limit exceeded"
                }))
                continue

            # Parse for logging/validation if it's JSON. Non-JSON payloads
            # are still broadcast as-is. ValueError covers JSONDecodeError and
            # the UnicodeDecodeError raised for undecodable binary frames.
            try:
                data = json.loads(message)
            except (ValueError, TypeError):
                data = None

            # Only JSON objects carry a "type"; arrays/scalars have no .get()
            # and must not tear down the connection.
            msg_type = data.get("type", "unknown") if isinstance(data, dict) else None

            # Optional: log specific important message types
            if msg_type in ("agent_register", "thought", "action"):
                logger.debug(f"Received {msg_type} from {addr}")

            # Handle git status requests from the operator cockpit (issue #1695)
            if msg_type == "git_status_request":
                # _get_git_status() shells out to git; run it in a worker
                # thread so the event loop keeps serving other clients.
                git_info = await asyncio.get_running_loop().run_in_executor(
                    None, _get_git_status
                )
                await websocket.send(json.dumps(git_info))
                continue

            # Broadcast to all OTHER clients
            targets = [
                client for client in clients
                if client != websocket and client.open
            ]
            if not targets:
                continue

            results = await asyncio.gather(
                *(client.send(message) for client in targets),
                return_exceptions=True,
            )
            for target_client, result in zip(targets, results):
                if isinstance(result, Exception):
                    logger.error(
                        f"Failed to send to client {target_client.remote_address}: {result}"
                    )
                    # Stop broadcasting to a peer we can no longer reach.
                    clients.discard(target_client)
    except websockets.exceptions.ConnectionClosed:
        logger.debug(f"Connection closed by client {addr}")
    except Exception as e:
        logger.error(f"Error handling client {addr}: {e}")
    finally:
        clients.discard(websocket)
        # Drop this connection's message history. id() values are reused once
        # an object is freed, so a stale entry would both leak memory and
        # charge a future connection for this one's messages.
        message_tracker.pop(connection_id, None)
        logger.info(f"Client disconnected {addr}. Total clients: {len(clients)}")
async def main():
    """Run both gateways until SIGINT/SIGTERM, then shut down gracefully.

    Logs the effective security configuration, installs signal handlers
    (where the platform supports them), serves the broadcast gateway on
    ``HOST:PORT`` and the PTY gateway on ``127.0.0.1:PTY_PORT``, and blocks
    until a shutdown signal arrives. Afterwards any client still registered
    and open is closed and the client registry is cleared.
    """
    # Log security configuration
    if AUTH_TOKEN:
        logger.info("Authentication: ENABLED (token required)")
    else:
        logger.warning("Authentication: DISABLED (no token required)")

    # Only claim "localhost only" for genuine loopback addresses; any other
    # value (a LAN address, "::", ...) is reachable from other machines.
    if HOST == "localhost" or HOST == "::1" or HOST.startswith("127."):
        logger.info(f"Host binding: {HOST} (localhost only)")
    elif HOST in ("0.0.0.0", "::"):
        logger.warning(f"Host binding: {HOST} (all interfaces) - SECURITY RISK")
    else:
        logger.warning(f"Host binding: {HOST} (non-loopback interface) - SECURITY RISK")

    logger.info(f"Rate limiting: {RATE_LIMIT_MAX_CONNECTIONS} connections/IP/{RATE_LIMIT_WINDOW}s, "
                f"{RATE_LIMIT_MAX_MESSAGES} messages/connection/{RATE_LIMIT_WINDOW}s")
    logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")

    # Set up signal handlers for graceful shutdown
    loop = asyncio.get_running_loop()
    stop = loop.create_future()

    def shutdown():
        # May fire more than once (e.g. repeated Ctrl-C); resolve only once.
        if not stop.done():
            stop.set_result(None)

    for sig in (signal.SIGINT, signal.SIGTERM):
        try:
            loop.add_signal_handler(sig, shutdown)
        except NotImplementedError:
            # Signal handlers not supported on Windows
            pass

    async with websockets.serve(broadcast_handler, HOST, PORT):
        logger.info("Gateway is ready and listening.")
        # Also start the PTY gateway on PTY_PORT (operator cockpit shell, issue #1695)
        async with websockets.serve(pty_handler, "127.0.0.1", PTY_PORT):
            logger.info(f"PTY shell gateway listening on ws://127.0.0.1:{PTY_PORT}/pty")
            await stop

    logger.info("Shutting down Nexus WS gateway...")
    # Close any remaining client connections (handlers may have already cleaned up)
    remaining = {c for c in clients if c.open}
    if remaining:
        logger.info(f"Closing {len(remaining)} active connections...")
        close_tasks = [client.close() for client in remaining]
        await asyncio.gather(*close_tasks, return_exceptions=True)
    clients.clear()
    logger.info("Shutdown complete.")
if __name__ == "__main__":
    # Script entry point: run the gateway until it stops or fails.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C on platforms without loop signal handlers: exit quietly.
        pass
    except Exception as exc:
        # Anything else is fatal; report it and exit with a failure status.
        logger.critical(f"Fatal server error: {exc}")
        sys.exit(1)