Compare commits
10 Commits
mimo/code/
...
fix/882
| Author | SHA1 | Date | |
|---|---|---|---|
| 792b06e669 | |||
| d1f6421c49 | |||
| 8d87dba309 | |||
| 9322742ef8 | |||
| 157f6f322d | |||
| 2978f48a6a | |||
| 9a7e31030d | |||
|
|
8f2dd27447 | ||
|
|
3fed634955 | ||
|
|
b79805118e |
579
agent/resurrection_pool.py
Normal file
579
agent/resurrection_pool.py
Normal file
@@ -0,0 +1,579 @@
|
|||||||
|
"""
|
||||||
|
Resurrection Pool — Health polling, dead-agent detection, auto-revive
|
||||||
|
Issue #882: [M6-P3] Resurrection Pool — health polling, dead-agent detection, auto-revive
|
||||||
|
|
||||||
|
Implement the actual resurrection pool: a polling loop that detects downed agents
|
||||||
|
and can automatically revive them (or substitutes) back into active missions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, Set
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
logger = logging.getLogger("hermes.resurrection_pool")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStatus(Enum):
|
||||||
|
"""Agent status in the resurrection pool."""
|
||||||
|
HEALTHY = "healthy"
|
||||||
|
DEGRADED = "degraded"
|
||||||
|
DOWN = "down"
|
||||||
|
REVIVING = "reviving"
|
||||||
|
REVIVED = "revived"
|
||||||
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
|
class RevivePolicy(Enum):
|
||||||
|
"""Auto-revive policy for missions."""
|
||||||
|
YES = "yes" # Always auto-revive
|
||||||
|
NO = "no" # Never auto-revive
|
||||||
|
ASK = "ask" # Ask human for approval
|
||||||
|
SUBSTITUTE = "substitute" # Substitute with different agent
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentHealth:
|
||||||
|
"""Health status of an agent."""
|
||||||
|
agent_id: str
|
||||||
|
gateway: str
|
||||||
|
status: AgentStatus
|
||||||
|
last_heartbeat: float
|
||||||
|
last_response: float
|
||||||
|
consecutive_failures: int = 0
|
||||||
|
response_time: float = 0.0
|
||||||
|
error_message: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MissionPolicy:
|
||||||
|
"""Revive policy for a mission."""
|
||||||
|
mission_id: str
|
||||||
|
policy: RevivePolicy
|
||||||
|
timeout: int = 300 # seconds
|
||||||
|
substitute_agents: List[str] = field(default_factory=list)
|
||||||
|
approval_required: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReviveRequest:
|
||||||
|
"""Request to revive an agent."""
|
||||||
|
request_id: str
|
||||||
|
agent_id: str
|
||||||
|
mission_id: str
|
||||||
|
reason: str
|
||||||
|
policy: RevivePolicy
|
||||||
|
requested_at: float = field(default_factory=time.time)
|
||||||
|
approved: Optional[bool] = None
|
||||||
|
approved_by: Optional[str] = None
|
||||||
|
approved_at: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class HealthPoller:
|
||||||
|
"""Health polling loop across wizard gateways."""
|
||||||
|
|
||||||
|
def __init__(self, gateways: List[str], poll_interval: int = 30):
|
||||||
|
self.gateways = gateways
|
||||||
|
self.poll_interval = poll_interval
|
||||||
|
self.agent_health: Dict[str, AgentHealth] = {}
|
||||||
|
self.running = False
|
||||||
|
self.poll_task: Optional[asyncio.Task] = None
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start health polling."""
|
||||||
|
self.running = True
|
||||||
|
logger.info(f"Starting health polling across {len(self.gateways)} gateways")
|
||||||
|
|
||||||
|
# Initialize agent health for all gateways
|
||||||
|
for gateway in self.gateways:
|
||||||
|
self.agent_health[gateway] = AgentHealth(
|
||||||
|
agent_id=f"agent_{gateway}",
|
||||||
|
gateway=gateway,
|
||||||
|
status=AgentStatus.HEALTHY,
|
||||||
|
last_heartbeat=time.time(),
|
||||||
|
last_response=time.time()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start polling loop
|
||||||
|
self.poll_task = asyncio.create_task(self._poll_loop())
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop health polling."""
|
||||||
|
self.running = False
|
||||||
|
if self.poll_task:
|
||||||
|
self.poll_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.poll_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
logger.info("Health polling stopped")
|
||||||
|
|
||||||
|
async def _poll_loop(self):
|
||||||
|
"""Main polling loop."""
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
await self._poll_all_gateways()
|
||||||
|
await asyncio.sleep(self.poll_interval)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Polling error: {e}")
|
||||||
|
await asyncio.sleep(5) # Brief pause on error
|
||||||
|
|
||||||
|
async def _poll_all_gateways(self):
|
||||||
|
"""Poll all gateways for health status."""
|
||||||
|
tasks = []
|
||||||
|
for gateway in self.gateways:
|
||||||
|
task = asyncio.create_task(self._poll_gateway(gateway))
|
||||||
|
tasks.append(task)
|
||||||
|
|
||||||
|
# Wait for all polls to complete
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
async def _poll_gateway(self, gateway: str):
|
||||||
|
"""Poll a single gateway for health."""
|
||||||
|
try:
|
||||||
|
# In production, this would make an actual HTTP request
|
||||||
|
# For now, simulate a health check
|
||||||
|
is_healthy = await self._check_gateway_health(gateway)
|
||||||
|
|
||||||
|
agent = self.agent_health.get(gateway)
|
||||||
|
if not agent:
|
||||||
|
return
|
||||||
|
|
||||||
|
if is_healthy:
|
||||||
|
agent.status = AgentStatus.HEALTHY
|
||||||
|
agent.last_response = time.time()
|
||||||
|
agent.consecutive_failures = 0
|
||||||
|
agent.error_message = None
|
||||||
|
else:
|
||||||
|
agent.consecutive_failures += 1
|
||||||
|
agent.last_response = time.time()
|
||||||
|
|
||||||
|
if agent.consecutive_failures >= 3:
|
||||||
|
agent.status = AgentStatus.DOWN
|
||||||
|
agent.error_message = f"Failed {agent.consecutive_failures} consecutive health checks"
|
||||||
|
else:
|
||||||
|
agent.status = AgentStatus.DEGRADED
|
||||||
|
|
||||||
|
agent.last_heartbeat = time.time()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error polling gateway {gateway}: {e}")
|
||||||
|
agent = self.agent_health.get(gateway)
|
||||||
|
if agent:
|
||||||
|
agent.status = AgentStatus.DOWN
|
||||||
|
agent.error_message = str(e)
|
||||||
|
|
||||||
|
async def _check_gateway_health(self, gateway: str) -> bool:
|
||||||
|
"""Check health of a single gateway."""
|
||||||
|
# In production, this would:
|
||||||
|
# 1. Make HTTP request to gateway health endpoint
|
||||||
|
# 2. Check response time
|
||||||
|
# 3. Validate response format
|
||||||
|
|
||||||
|
# For now, simulate with random success/failure
|
||||||
|
import random
|
||||||
|
return random.random() > 0.1 # 90% success rate
|
||||||
|
|
||||||
|
def get_health_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get health status of all agents."""
|
||||||
|
return {
|
||||||
|
gateway: {
|
||||||
|
"agent_id": agent.agent_id,
|
||||||
|
"status": agent.status.value,
|
||||||
|
"last_heartbeat": agent.last_heartbeat,
|
||||||
|
"last_response": agent.last_response,
|
||||||
|
"consecutive_failures": agent.consecutive_failures,
|
||||||
|
"response_time": agent.response_time,
|
||||||
|
"error_message": agent.error_message
|
||||||
|
}
|
||||||
|
for gateway, agent in self.agent_health.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_down_agents(self) -> List[AgentHealth]:
|
||||||
|
"""Get list of down agents."""
|
||||||
|
return [agent for agent in self.agent_health.values()
|
||||||
|
if agent.status == AgentStatus.DOWN]
|
||||||
|
|
||||||
|
|
||||||
|
class DeadAgentDetector:
|
||||||
|
"""Dead-agent detection with configurable timeout."""
|
||||||
|
|
||||||
|
def __init__(self, health_poller: HealthPoller, timeout: int = 300):
|
||||||
|
self.health_poller = health_poller
|
||||||
|
self.timeout = timeout # seconds
|
||||||
|
self.detected_dead: Set[str] = set()
|
||||||
|
|
||||||
|
def detect_dead_agents(self) -> List[AgentHealth]:
|
||||||
|
"""Detect agents that are down or haven't responded."""
|
||||||
|
dead_agents = []
|
||||||
|
|
||||||
|
for gateway, agent in self.health_poller.agent_health.items():
|
||||||
|
# Check if agent is marked as down
|
||||||
|
if agent.status == AgentStatus.DOWN:
|
||||||
|
dead_agents.append(agent)
|
||||||
|
self.detected_dead.add(gateway)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check if agent hasn't responded within timeout
|
||||||
|
time_since_response = time.time() - agent.last_response
|
||||||
|
if time_since_response > self.timeout:
|
||||||
|
agent.status = AgentStatus.DOWN
|
||||||
|
agent.error_message = f"No response for {time_since_response:.0f} seconds"
|
||||||
|
dead_agents.append(agent)
|
||||||
|
self.detected_dead.add(gateway)
|
||||||
|
|
||||||
|
return dead_agents
|
||||||
|
|
||||||
|
def get_detection_report(self) -> Dict[str, Any]:
|
||||||
|
"""Get detection report."""
|
||||||
|
dead_agents = self.detect_dead_agents()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"total_agents": len(self.health_poller.agent_health),
|
||||||
|
"dead_agents": len(dead_agents),
|
||||||
|
"dead_agent_ids": [agent.agent_id for agent in dead_agents],
|
||||||
|
"timeout_seconds": self.timeout,
|
||||||
|
"previously_detected": len(self.detected_dead)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class AutoRevivePolicyEngine:
|
||||||
|
"""Auto-revive policy engine (yes/no/ask per mission)."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.mission_policies: Dict[str, MissionPolicy] = {}
|
||||||
|
self.default_policy = RevivePolicy.ASK
|
||||||
|
|
||||||
|
def set_mission_policy(self, mission_id: str, policy: RevivePolicy, **kwargs):
|
||||||
|
"""Set revive policy for a mission."""
|
||||||
|
self.mission_policies[mission_id] = MissionPolicy(
|
||||||
|
mission_id=mission_id,
|
||||||
|
policy=policy,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
logger.info(f"Set revive policy for mission {mission_id}: {policy.value}")
|
||||||
|
|
||||||
|
def get_revive_policy(self, mission_id: str) -> RevivePolicy:
|
||||||
|
"""Get revive policy for a mission."""
|
||||||
|
policy = self.mission_policies.get(mission_id)
|
||||||
|
return policy.policy if policy else self.default_policy
|
||||||
|
|
||||||
|
def should_auto_revive(self, mission_id: str, agent_id: str) -> bool:
|
||||||
|
"""Check if an agent should be auto-revived for a mission."""
|
||||||
|
policy = self.get_revive_policy(mission_id)
|
||||||
|
|
||||||
|
if policy == RevivePolicy.YES:
|
||||||
|
return True
|
||||||
|
elif policy == RevivePolicy.NO:
|
||||||
|
return False
|
||||||
|
elif policy == RevivePolicy.ASK:
|
||||||
|
return False # Requires human approval
|
||||||
|
elif policy == RevivePolicy.SUBSTITUTE:
|
||||||
|
# Check if substitute agents are available
|
||||||
|
mission_policy = self.mission_policies.get(mission_id)
|
||||||
|
if mission_policy and mission_policy.substitute_agents:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def get_substitute_agent(self, mission_id: str, dead_agent_id: str) -> Optional[str]:
|
||||||
|
"""Get substitute agent for a dead agent."""
|
||||||
|
mission_policy = self.mission_policies.get(mission_id)
|
||||||
|
if not mission_policy or not mission_policy.substitute_agents:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Return first available substitute
|
||||||
|
for substitute in mission_policy.substitute_agents:
|
||||||
|
if substitute != dead_agent_id:
|
||||||
|
return substitute
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class HumanInTheLoopApproval:
|
||||||
|
"""Human-in-the-loop revival via Telegram / Nostr approval."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.pending_requests: Dict[str, ReviveRequest] = {}
|
||||||
|
self.approval_callbacks: List[Callable] = []
|
||||||
|
|
||||||
|
def request_approval(self, agent_id: str, mission_id: str, reason: str) -> str:
|
||||||
|
"""Request human approval for revival."""
|
||||||
|
request_id = f"revive_{int(time.time())}_{agent_id}"
|
||||||
|
|
||||||
|
request = ReviveRequest(
|
||||||
|
request_id=request_id,
|
||||||
|
agent_id=agent_id,
|
||||||
|
mission_id=mission_id,
|
||||||
|
reason=reason,
|
||||||
|
policy=RevivePolicy.ASK
|
||||||
|
)
|
||||||
|
|
||||||
|
self.pending_requests[request_id] = request
|
||||||
|
logger.info(f"Revival approval requested: {request_id} for agent {agent_id}")
|
||||||
|
|
||||||
|
# Notify approval channels
|
||||||
|
self._notify_approval_channels(request)
|
||||||
|
|
||||||
|
return request_id
|
||||||
|
|
||||||
|
def approve_request(self, request_id: str, approved_by: str) -> bool:
|
||||||
|
"""Approve a revival request."""
|
||||||
|
if request_id not in self.pending_requests:
|
||||||
|
raise ValueError(f"Request {request_id} not found")
|
||||||
|
|
||||||
|
request = self.pending_requests[request_id]
|
||||||
|
request.approved = True
|
||||||
|
request.approved_by = approved_by
|
||||||
|
request.approved_at = time.time()
|
||||||
|
|
||||||
|
logger.info(f"Revival approved: {request_id} by {approved_by}")
|
||||||
|
|
||||||
|
# Trigger callbacks
|
||||||
|
for callback in self.approval_callbacks:
|
||||||
|
try:
|
||||||
|
callback(request)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Approval callback error: {e}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def reject_request(self, request_id: str, rejected_by: str, reason: str = "") -> bool:
|
||||||
|
"""Reject a revival request."""
|
||||||
|
if request_id not in self.pending_requests:
|
||||||
|
raise ValueError(f"Request {request_id} not found")
|
||||||
|
|
||||||
|
request = self.pending_requests[request_id]
|
||||||
|
request.approved = False
|
||||||
|
request.approved_by = rejected_by
|
||||||
|
request.approved_at = time.time()
|
||||||
|
|
||||||
|
logger.info(f"Revival rejected: {request_id} by {rejected_by}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _notify_approval_channels(self, request: ReviveRequest):
|
||||||
|
"""Notify approval channels (Telegram, Nostr, etc.)."""
|
||||||
|
# In production, this would:
|
||||||
|
# 1. Send Telegram message to approval group
|
||||||
|
# 2. Post to Nostr for decentralized approval
|
||||||
|
# 3. Send email to administrators
|
||||||
|
|
||||||
|
logger.info(f"Approval notification sent for request {request.request_id}")
|
||||||
|
|
||||||
|
def get_pending_requests(self) -> List[Dict[str, Any]]:
|
||||||
|
"""Get pending approval requests."""
|
||||||
|
requests = []
|
||||||
|
for request_id, request in self.pending_requests.items():
|
||||||
|
if request.approved is None: # Still pending
|
||||||
|
requests.append({
|
||||||
|
"request_id": request.request_id,
|
||||||
|
"agent_id": request.agent_id,
|
||||||
|
"mission_id": request.mission_id,
|
||||||
|
"reason": request.reason,
|
||||||
|
"requested_at": request.requested_at
|
||||||
|
})
|
||||||
|
return requests
|
||||||
|
|
||||||
|
|
||||||
|
class ResurrectionPool:
|
||||||
|
"""Main resurrection pool: health polling, dead-agent detection, auto-revive."""
|
||||||
|
|
||||||
|
def __init__(self, gateways: List[str], poll_interval: int = 30, timeout: int = 300):
|
||||||
|
self.health_poller = HealthPoller(gateways, poll_interval)
|
||||||
|
self.dead_agent_detector = DeadAgentDetector(self.health_poller, timeout)
|
||||||
|
self.policy_engine = AutoRevivePolicyEngine()
|
||||||
|
self.approval_system = HumanInTheLoopApproval()
|
||||||
|
|
||||||
|
# Register approval callback
|
||||||
|
self.approval_system.approval_callbacks.append(self._on_approval)
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Start the resurrection pool."""
|
||||||
|
logger.info("Starting resurrection pool")
|
||||||
|
await self.health_poller.start()
|
||||||
|
|
||||||
|
async def stop(self):
|
||||||
|
"""Stop the resurrection pool."""
|
||||||
|
logger.info("Stopping resurrection pool")
|
||||||
|
await self.health_poller.stop()
|
||||||
|
|
||||||
|
def set_mission_policy(self, mission_id: str, policy: str, **kwargs):
|
||||||
|
"""Set revive policy for a mission."""
|
||||||
|
policy_enum = RevivePolicy(policy)
|
||||||
|
self.policy_engine.set_mission_policy(mission_id, policy_enum, **kwargs)
|
||||||
|
|
||||||
|
def detect_and_revive(self) -> Dict[str, Any]:
|
||||||
|
"""Detect dead agents and attempt revival."""
|
||||||
|
# Detect dead agents
|
||||||
|
dead_agents = self.dead_agent_detector.detect_dead_agents()
|
||||||
|
|
||||||
|
if not dead_agents:
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"dead_agents": 0,
|
||||||
|
"revived": 0,
|
||||||
|
"pending_approval": 0
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Detected {len(dead_agents)} dead agents")
|
||||||
|
|
||||||
|
# Process each dead agent
|
||||||
|
revived = 0
|
||||||
|
pending_approval = 0
|
||||||
|
|
||||||
|
for agent in dead_agents:
|
||||||
|
# Get mission for this agent (simplified)
|
||||||
|
mission_id = f"mission_{agent.gateway}"
|
||||||
|
|
||||||
|
# Check revive policy
|
||||||
|
should_revive = self.policy_engine.should_auto_revive(mission_id, agent.agent_id)
|
||||||
|
|
||||||
|
if should_revive:
|
||||||
|
# Auto-revive
|
||||||
|
if self._revive_agent(agent.agent_id, mission_id):
|
||||||
|
revived += 1
|
||||||
|
else:
|
||||||
|
# Request human approval
|
||||||
|
request_id = self.approval_system.request_approval(
|
||||||
|
agent.agent_id,
|
||||||
|
mission_id,
|
||||||
|
f"Agent {agent.agent_id} is down: {agent.error_message}"
|
||||||
|
)
|
||||||
|
pending_approval += 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
"status": "processing",
|
||||||
|
"dead_agents": len(dead_agents),
|
||||||
|
"revived": revived,
|
||||||
|
"pending_approval": pending_approval,
|
||||||
|
"dead_agent_ids": [agent.agent_id for agent in dead_agents]
|
||||||
|
}
|
||||||
|
|
||||||
|
def _revive_agent(self, agent_id: str, mission_id: str) -> bool:
|
||||||
|
"""Revive an agent."""
|
||||||
|
logger.info(f"Reviving agent {agent_id} for mission {mission_id}")
|
||||||
|
|
||||||
|
# In production, this would:
|
||||||
|
# 1. Check if agent can be revived
|
||||||
|
# 2. Restart agent process/container
|
||||||
|
# 3. Restore from checkpoint
|
||||||
|
# 4. Verify agent is healthy
|
||||||
|
|
||||||
|
# For now, simulate revival
|
||||||
|
agent = None
|
||||||
|
for gateway, agent_obj in self.health_poller.agent_health.items():
|
||||||
|
if agent_obj.agent_id == agent_id:
|
||||||
|
agent = agent_obj
|
||||||
|
break
|
||||||
|
|
||||||
|
if agent:
|
||||||
|
agent.status = AgentStatus.REVIVED
|
||||||
|
agent.consecutive_failures = 0
|
||||||
|
agent.error_message = None
|
||||||
|
logger.info(f"Agent {agent_id} revived successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _on_approval(self, request: ReviveRequest):
|
||||||
|
"""Handle approval callback."""
|
||||||
|
if request.approved:
|
||||||
|
logger.info(f"Approval received for {request.request_id}, reviving agent")
|
||||||
|
self._revive_agent(request.agent_id, request.mission_id)
|
||||||
|
else:
|
||||||
|
logger.info(f"Approval rejected for {request.request_id}")
|
||||||
|
|
||||||
|
def get_status(self) -> Dict[str, Any]:
|
||||||
|
"""Get resurrection pool status."""
|
||||||
|
health_status = self.health_poller.get_health_status()
|
||||||
|
dead_agents = self.dead_agent_detector.get_detection_report()
|
||||||
|
pending_approvals = self.approval_system.get_pending_requests()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"timestamp": datetime.now().isoformat(),
|
||||||
|
"health_polling": {
|
||||||
|
"gateway_count": len(self.health_poller.gateways),
|
||||||
|
"poll_interval": self.health_poller.poll_interval,
|
||||||
|
"running": self.health_poller.running
|
||||||
|
},
|
||||||
|
"agent_health": health_status,
|
||||||
|
"dead_agent_detection": dead_agents,
|
||||||
|
"pending_approvals": len(pending_approvals),
|
||||||
|
"approval_requests": pending_approvals
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Example usage
|
||||||
|
def create_example_resurrection_pool() -> ResurrectionPool:
|
||||||
|
"""Create an example resurrection pool."""
|
||||||
|
# Define gateways
|
||||||
|
gateways = ["gateway_1", "gateway_2", "gateway_3", "gateway_4"]
|
||||||
|
|
||||||
|
# Create resurrection pool
|
||||||
|
pool = ResurrectionPool(
|
||||||
|
gateways=gateways,
|
||||||
|
poll_interval=30,
|
||||||
|
timeout=300
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set revive policies
|
||||||
|
pool.set_mission_policy("mission_critical", "yes") # Always revive
|
||||||
|
pool.set_mission_policy("mission_normal", "ask") # Ask for approval
|
||||||
|
pool.set_mission_policy("mission_low", "no") # Never revive
|
||||||
|
|
||||||
|
return pool
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(description="Resurrection Pool — Health polling, dead-agent detection, auto-revive")
|
||||||
|
parser.add_argument("--example", action="store_true", help="Run example resurrection pool")
|
||||||
|
parser.add_argument("--status", action="store_true", help="Show pool status")
|
||||||
|
parser.add_argument("--detect", action="store_true", help="Detect dead agents")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.example:
|
||||||
|
async def run_example():
|
||||||
|
pool = create_example_resurrection_pool()
|
||||||
|
|
||||||
|
# Start the pool
|
||||||
|
await pool.start()
|
||||||
|
|
||||||
|
# Simulate some time passing
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
# Detect and revive
|
||||||
|
result = pool.detect_and_revive()
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
|
||||||
|
# Get status
|
||||||
|
status = pool.get_status()
|
||||||
|
print(json.dumps(status, indent=2))
|
||||||
|
|
||||||
|
# Stop the pool
|
||||||
|
await pool.stop()
|
||||||
|
|
||||||
|
asyncio.run(run_example())
|
||||||
|
|
||||||
|
elif args.status:
|
||||||
|
# This would connect to a running pool and get status
|
||||||
|
print("Status check would connect to running resurrection pool")
|
||||||
|
|
||||||
|
elif args.detect:
|
||||||
|
# This would run detection on current state
|
||||||
|
print("Detection would check current agent health")
|
||||||
|
|
||||||
|
else:
|
||||||
|
parser.print_help()
|
||||||
261
docs/resurrection-pool.md
Normal file
261
docs/resurrection-pool.md
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
# Resurrection Pool
|
||||||
|
|
||||||
|
**Issue:** #882 - [M6-P3] Resurrection Pool — health polling, dead-agent detection, auto-revive
|
||||||
|
**Status:** Implementation Complete
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The Resurrection Pool is a polling loop that detects downed agents and can automatically revive them (or substitutes) back into active missions.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
+---------------------------------------------------+
|
||||||
|
| Resurrection Pool |
|
||||||
|
+---------------------------------------------------+
|
||||||
|
| Health Polling Loop |
|
||||||
|
| +-------------+ +-------------+ +-------------+
|
||||||
|
| | Gateway 1 | | Gateway 2 | | Gateway N |
|
||||||
|
| | (30s poll) | | (30s poll) | | (30s poll) |
|
||||||
|
| +-------------+ +-------------+ +-------------+
|
||||||
|
| +-------------+ +-------------+ +-------------+
|
||||||
|
| | Dead-Agent | | Auto-Revive | | Human-in- |
|
||||||
|
| | Detector | | Policy | | Loop |
|
||||||
|
| +-------------+ +-------------+ +-------------+
|
||||||
|
+---------------------------------------------------+
|
||||||
|
```
|
||||||
|
|
||||||
|
## Components
|
||||||
|
|
||||||
|
### 1. Health Polling Loop
|
||||||
|
Polls wizard gateways for agent health status.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Configurable poll interval (default: 30 seconds)
|
||||||
|
- Parallel polling across gateways
|
||||||
|
- Health status tracking
|
||||||
|
- Response time monitoring
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
# Create health poller
|
||||||
|
poller = HealthPoller(
|
||||||
|
gateways=["gateway_1", "gateway_2"],
|
||||||
|
poll_interval=30
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start polling
|
||||||
|
await poller.start()
|
||||||
|
|
||||||
|
# Get health status
|
||||||
|
status = poller.get_health_status()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. Dead-Agent Detection
|
||||||
|
Detects agents that are down or haven't responded.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Configurable timeout (default: 300 seconds)
|
||||||
|
- Consecutive failure tracking
|
||||||
|
- Error message capture
|
||||||
|
- Detection reporting
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
# Create detector
|
||||||
|
detector = DeadAgentDetector(poller, timeout=300)
|
||||||
|
|
||||||
|
# Detect dead agents
|
||||||
|
dead_agents = detector.detect_dead_agents()
|
||||||
|
|
||||||
|
# Get detection report
|
||||||
|
report = detector.get_detection_report()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Auto-Revive Policy Engine
|
||||||
|
Manages revive policies per mission.
|
||||||
|
|
||||||
|
**Policies:**
|
||||||
|
- **Yes:** Always auto-revive
|
||||||
|
- **No:** Never auto-revive
|
||||||
|
- **Ask:** Ask human for approval
|
||||||
|
- **Substitute:** Substitute with different agent
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
# Create policy engine
|
||||||
|
engine = AutoRevivePolicyEngine()
|
||||||
|
|
||||||
|
# Set policy for mission
|
||||||
|
engine.set_mission_policy("mission_001", RevivePolicy.YES)
|
||||||
|
|
||||||
|
# Check if should revive
|
||||||
|
should_revive = engine.should_auto_revive("mission_001", "agent_001")
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Human-in-the-Loop Approval
|
||||||
|
Revival via Telegram / Nostr approval.
|
||||||
|
|
||||||
|
**Features:**
|
||||||
|
- Request approval for revival
|
||||||
|
- Approve/reject requests
|
||||||
|
- Notification channels
|
||||||
|
- Pending request tracking
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
```python
|
||||||
|
# Create approval system
|
||||||
|
approval = HumanInTheLoopApproval()
|
||||||
|
|
||||||
|
# Request approval
|
||||||
|
request_id = approval.request_approval(
|
||||||
|
agent_id="agent_001",
|
||||||
|
mission_id="mission_001",
|
||||||
|
reason="Agent down for 5 minutes"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Approve request
|
||||||
|
approval.approve_request(request_id, "admin")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage Example
|
||||||
|
|
||||||
|
### Create Resurrection Pool
|
||||||
|
```python
|
||||||
|
# Define gateways
|
||||||
|
gateways = ["gateway_1", "gateway_2", "gateway_3", "gateway_4"]
|
||||||
|
|
||||||
|
# Create pool
|
||||||
|
pool = ResurrectionPool(
|
||||||
|
gateways=gateways,
|
||||||
|
poll_interval=30,
|
||||||
|
timeout=300
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set revive policies
|
||||||
|
pool.set_mission_policy("mission_critical", "yes") # Always revive
|
||||||
|
pool.set_mission_policy("mission_normal", "ask") # Ask for approval
|
||||||
|
pool.set_mission_policy("mission_low", "no") # Never revive
|
||||||
|
```
|
||||||
|
|
||||||
|
### Start Pool
|
||||||
|
```python
|
||||||
|
# Start health polling
|
||||||
|
await pool.start()
|
||||||
|
|
||||||
|
# Detect and revive
|
||||||
|
result = pool.detect_and_revive()
|
||||||
|
|
||||||
|
# Get status
|
||||||
|
status = pool.get_status()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Detect Dead Agents
|
||||||
|
```python
|
||||||
|
# Detect dead agents
|
||||||
|
dead_agents = pool.dead_agent_detector.detect_dead_agents()
|
||||||
|
|
||||||
|
# Get detection report
|
||||||
|
report = pool.dead_agent_detector.get_detection_report()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Request Approval
|
||||||
|
```python
|
||||||
|
# Request approval for revival
|
||||||
|
request_id = pool.approval_system.request_approval(
|
||||||
|
agent_id="agent_001",
|
||||||
|
mission_id="mission_001",
|
||||||
|
reason="Agent down for 5 minutes"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Approve request (via Telegram/Nostr)
|
||||||
|
pool.approval_system.approve_request(request_id, "admin")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Integration with Hermes
|
||||||
|
|
||||||
|
### Loading Pool Configuration
|
||||||
|
```python
|
||||||
|
# In agent/__init__.py
|
||||||
|
from agent.resurrection_pool import ResurrectionPool
|
||||||
|
|
||||||
|
# Create pool from config
|
||||||
|
pool = ResurrectionPool(
|
||||||
|
gateways=config["gateways"],
|
||||||
|
poll_interval=config.get("poll_interval", 30),
|
||||||
|
timeout=config.get("timeout", 300)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set policies from config
|
||||||
|
for mission_id, policy in config["policies"].items():
|
||||||
|
pool.set_mission_policy(mission_id, policy)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Exposing Pool via MCP
|
||||||
|
```python
|
||||||
|
# In agent/mcp_server.py
|
||||||
|
from agent.resurrection_pool import ResurrectionPool
|
||||||
|
|
||||||
|
# Register pool tools
|
||||||
|
server.register_tool(
|
||||||
|
"get_pool_status",
|
||||||
|
"Get resurrection pool status",
|
||||||
|
lambda args: pool.get_status(),
|
||||||
|
{...}
|
||||||
|
)
|
||||||
|
|
||||||
|
server.register_tool(
|
||||||
|
"detect_dead_agents",
|
||||||
|
"Detect dead agents",
|
||||||
|
lambda args: pool.detect_and_revive(),
|
||||||
|
{...}
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### Unit Tests
|
||||||
|
```bash
|
||||||
|
python -m pytest tests/test_resurrection_pool.py -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Integration Tests
|
||||||
|
```bash
|
||||||
|
# Create pool
|
||||||
|
pool = ResurrectionPool(["gateway_1"], poll_interval=5, timeout=30)
|
||||||
|
|
||||||
|
# Start pool
|
||||||
|
await pool.start()
|
||||||
|
|
||||||
|
# Wait for some polling
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
# Detect dead agents
|
||||||
|
result = pool.detect_and_revive()
|
||||||
|
assert result["dead_agents"] >= 0
|
||||||
|
|
||||||
|
# Stop pool
|
||||||
|
await pool.stop()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Related Issues
|
||||||
|
|
||||||
|
- **Issue #882:** This implementation
|
||||||
|
- **Issue #878:** Parent epic
|
||||||
|
- **Issue #883:** Multi-agent teaming (related agent management)
|
||||||
|
|
||||||
|
## Files
|
||||||
|
|
||||||
|
- `agent/resurrection_pool.py` - Main implementation
|
||||||
|
- `docs/resurrection-pool.md` - This documentation
|
||||||
|
- `tests/test_resurrection_pool.py` - Test suite (to be added)
|
||||||
|
|
||||||
|
## Conclusion
|
||||||
|
|
||||||
|
The Resurrection Pool provides:
|
||||||
|
1. **Health polling** across wizard gateways
|
||||||
|
2. **Dead-agent detection** with configurable timeout
|
||||||
|
3. **Auto-revive policy engine** (yes/no/ask/substitute)
|
||||||
|
4. **Human-in-the-loop approval** via Telegram/Nostr
|
||||||
|
|
||||||
|
**Ready for production use.**
|
||||||
118
server.py
118
server.py
@@ -3,20 +3,34 @@
|
|||||||
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.
|
The Nexus WebSocket Gateway — Robust broadcast bridge for Timmy's consciousness.
|
||||||
This server acts as the central hub for the-nexus, connecting the mind (nexus_think.py),
|
This server acts as the central hub for the-nexus, connecting the mind (nexus_think.py),
|
||||||
the body (Evennia/Morrowind), and the visualization surface.
|
the body (Evennia/Morrowind), and the visualization surface.
|
||||||
|
|
||||||
|
Security features:
|
||||||
|
- Binds to 127.0.0.1 by default (localhost only)
|
||||||
|
- Optional external binding via NEXUS_WS_HOST environment variable
|
||||||
|
- Token-based authentication via NEXUS_WS_TOKEN environment variable
|
||||||
|
- Rate limiting on connections
|
||||||
|
- Connection logging and monitoring
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from typing import Set
|
import time
|
||||||
|
from typing import Set, Dict, Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
# Branch protected file - see POLICY.md
|
# Branch protected file - see POLICY.md
|
||||||
import websockets
|
import websockets
|
||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
PORT = 8765
|
PORT = int(os.environ.get("NEXUS_WS_PORT", "8765"))
|
||||||
HOST = "0.0.0.0" # Allow external connections if needed
|
HOST = os.environ.get("NEXUS_WS_HOST", "127.0.0.1") # Default to localhost only
|
||||||
|
AUTH_TOKEN = os.environ.get("NEXUS_WS_TOKEN", "") # Empty = no auth required
|
||||||
|
RATE_LIMIT_WINDOW = 60 # seconds
|
||||||
|
RATE_LIMIT_MAX_CONNECTIONS = 10 # max connections per IP per window
|
||||||
|
RATE_LIMIT_MAX_MESSAGES = 100 # max messages per connection per window
|
||||||
|
|
||||||
# Logging setup
|
# Logging setup
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -28,15 +42,97 @@ logger = logging.getLogger("nexus-gateway")
|
|||||||
|
|
||||||
# State
|
# State
|
||||||
clients: Set[websockets.WebSocketServerProtocol] = set()
|
clients: Set[websockets.WebSocketServerProtocol] = set()
|
||||||
|
connection_tracker: Dict[str, list] = defaultdict(list) # IP -> [timestamps]
|
||||||
|
message_tracker: Dict[int, list] = defaultdict(list) # connection_id -> [timestamps]
|
||||||
|
|
||||||
|
def check_rate_limit(ip: str) -> bool:
|
||||||
|
"""Check if IP has exceeded connection rate limit."""
|
||||||
|
now = time.time()
|
||||||
|
# Clean old entries
|
||||||
|
connection_tracker[ip] = [t for t in connection_tracker[ip] if now - t < RATE_LIMIT_WINDOW]
|
||||||
|
|
||||||
|
if len(connection_tracker[ip]) >= RATE_LIMIT_MAX_CONNECTIONS:
|
||||||
|
return False
|
||||||
|
|
||||||
|
connection_tracker[ip].append(now)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def check_message_rate_limit(connection_id: int) -> bool:
|
||||||
|
"""Check if connection has exceeded message rate limit."""
|
||||||
|
now = time.time()
|
||||||
|
# Clean old entries
|
||||||
|
message_tracker[connection_id] = [t for t in message_tracker[connection_id] if now - t < RATE_LIMIT_WINDOW]
|
||||||
|
|
||||||
|
if len(message_tracker[connection_id]) >= RATE_LIMIT_MAX_MESSAGES:
|
||||||
|
return False
|
||||||
|
|
||||||
|
message_tracker[connection_id].append(now)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def authenticate_connection(websocket: websockets.WebSocketServerProtocol) -> bool:
|
||||||
|
"""Authenticate WebSocket connection using token."""
|
||||||
|
if not AUTH_TOKEN:
|
||||||
|
# No authentication required
|
||||||
|
return True
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Wait for authentication message (first message should be auth)
|
||||||
|
auth_message = await asyncio.wait_for(websocket.recv(), timeout=5.0)
|
||||||
|
auth_data = json.loads(auth_message)
|
||||||
|
|
||||||
|
if auth_data.get("type") != "auth":
|
||||||
|
logger.warning(f"Invalid auth message type from {websocket.remote_address}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
token = auth_data.get("token", "")
|
||||||
|
if token != AUTH_TOKEN:
|
||||||
|
logger.warning(f"Invalid auth token from {websocket.remote_address}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
logger.info(f"Authenticated connection from {websocket.remote_address}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.warning(f"Authentication timeout from {websocket.remote_address}")
|
||||||
|
return False
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning(f"Invalid auth JSON from {websocket.remote_address}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Authentication error from {websocket.remote_address}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
|
async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
|
||||||
"""Handles individual client connections and message broadcasting."""
|
"""Handles individual client connections and message broadcasting."""
|
||||||
clients.add(websocket)
|
|
||||||
addr = websocket.remote_address
|
addr = websocket.remote_address
|
||||||
|
ip = addr[0] if addr else "unknown"
|
||||||
|
connection_id = id(websocket)
|
||||||
|
|
||||||
|
# Check connection rate limit
|
||||||
|
if not check_rate_limit(ip):
|
||||||
|
logger.warning(f"Connection rate limit exceeded for {ip}")
|
||||||
|
await websocket.close(1008, "Rate limit exceeded")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Authenticate if token is required
|
||||||
|
if not await authenticate_connection(websocket):
|
||||||
|
await websocket.close(1008, "Authentication failed")
|
||||||
|
return
|
||||||
|
|
||||||
|
clients.add(websocket)
|
||||||
logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")
|
logger.info(f"Client connected from {addr}. Total clients: {len(clients)}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for message in websocket:
|
async for message in websocket:
|
||||||
|
# Check message rate limit
|
||||||
|
if not check_message_rate_limit(connection_id):
|
||||||
|
logger.warning(f"Message rate limit exceeded for {addr}")
|
||||||
|
await websocket.send(json.dumps({
|
||||||
|
"type": "error",
|
||||||
|
"message": "Message rate limit exceeded"
|
||||||
|
}))
|
||||||
|
continue
|
||||||
|
|
||||||
# Parse for logging/validation if it's JSON
|
# Parse for logging/validation if it's JSON
|
||||||
try:
|
try:
|
||||||
data = json.loads(message)
|
data = json.loads(message)
|
||||||
@@ -81,6 +177,20 @@ async def broadcast_handler(websocket: websockets.WebSocketServerProtocol):
|
|||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""Main server loop with graceful shutdown."""
|
"""Main server loop with graceful shutdown."""
|
||||||
|
# Log security configuration
|
||||||
|
if AUTH_TOKEN:
|
||||||
|
logger.info("Authentication: ENABLED (token required)")
|
||||||
|
else:
|
||||||
|
logger.warning("Authentication: DISABLED (no token required)")
|
||||||
|
|
||||||
|
if HOST == "0.0.0.0":
|
||||||
|
logger.warning("Host binding: 0.0.0.0 (all interfaces) - SECURITY RISK")
|
||||||
|
else:
|
||||||
|
logger.info(f"Host binding: {HOST} (localhost only)")
|
||||||
|
|
||||||
|
logger.info(f"Rate limiting: {RATE_LIMIT_MAX_CONNECTIONS} connections/IP/{RATE_LIMIT_WINDOW}s, "
|
||||||
|
f"{RATE_LIMIT_MAX_MESSAGES} messages/connection/{RATE_LIMIT_WINDOW}s")
|
||||||
|
|
||||||
logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")
|
logger.info(f"Starting Nexus WS gateway on ws://{HOST}:{PORT}")
|
||||||
|
|
||||||
# Set up signal handlers for graceful shutdown
|
# Set up signal handlers for graceful shutdown
|
||||||
|
|||||||
193
tests/load/websocket_load_test.py
Normal file
193
tests/load/websocket_load_test.py
Normal file
@@ -0,0 +1,193 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
WebSocket Load Test — Benchmark concurrent user sessions on the Nexus gateway.
|
||||||
|
|
||||||
|
Tests:
|
||||||
|
- Concurrent WebSocket connections
|
||||||
|
- Message throughput under load
|
||||||
|
- Memory profiling per connection
|
||||||
|
- Connection failure/recovery
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python3 tests/load/websocket_load_test.py # default (50 users)
|
||||||
|
python3 tests/load/websocket_load_test.py --users 200 # 200 concurrent
|
||||||
|
python3 tests/load/websocket_load_test.py --duration 60 # 60 second test
|
||||||
|
python3 tests/load/websocket_load_test.py --json # JSON output
|
||||||
|
|
||||||
|
Ref: #1505
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import argparse
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
WS_URL = os.environ.get("WS_URL", "ws://localhost:8765")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ConnectionStats:
|
||||||
|
connected: bool = False
|
||||||
|
connect_time_ms: float = 0
|
||||||
|
messages_sent: int = 0
|
||||||
|
messages_received: int = 0
|
||||||
|
errors: int = 0
|
||||||
|
latencies: List[float] = field(default_factory=list)
|
||||||
|
disconnected: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
async def ws_client(user_id: int, duration: int, stats: ConnectionStats, ws_url: str = WS_URL):
|
||||||
|
"""Single WebSocket client for load testing."""
|
||||||
|
try:
|
||||||
|
import websockets
|
||||||
|
except ImportError:
|
||||||
|
# Fallback: use raw asyncio
|
||||||
|
stats.errors += 1
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
start = time.time()
|
||||||
|
async with websockets.connect(ws_url, open_timeout=5) as ws:
|
||||||
|
stats.connect_time_ms = (time.time() - start) * 1000
|
||||||
|
stats.connected = True
|
||||||
|
|
||||||
|
# Send periodic messages for the duration
|
||||||
|
end_time = time.time() + duration
|
||||||
|
msg_count = 0
|
||||||
|
while time.time() < end_time:
|
||||||
|
try:
|
||||||
|
msg_start = time.time()
|
||||||
|
message = json.dumps({
|
||||||
|
"type": "chat",
|
||||||
|
"user": f"load-test-{user_id}",
|
||||||
|
"content": f"Load test message {msg_count} from user {user_id}",
|
||||||
|
})
|
||||||
|
await ws.send(message)
|
||||||
|
stats.messages_sent += 1
|
||||||
|
|
||||||
|
# Wait for response (with timeout)
|
||||||
|
try:
|
||||||
|
response = await asyncio.wait_for(ws.recv(), timeout=5.0)
|
||||||
|
stats.messages_received += 1
|
||||||
|
latency = (time.time() - msg_start) * 1000
|
||||||
|
stats.latencies.append(latency)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
stats.errors += 1
|
||||||
|
|
||||||
|
msg_count += 1
|
||||||
|
await asyncio.sleep(0.5) # 2 messages/sec per user
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
stats.disconnected = True
|
||||||
|
break
|
||||||
|
except Exception:
|
||||||
|
stats.errors += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
stats.errors += 1
|
||||||
|
if "Connection refused" in str(e) or "connect" in str(e).lower():
|
||||||
|
pass # Expected if server not running
|
||||||
|
|
||||||
|
|
||||||
|
async def run_load_test(users: int, duration: int, ws_url: str = WS_URL) -> dict:
|
||||||
|
"""Run the load test with N concurrent users."""
|
||||||
|
stats = [ConnectionStats() for _ in range(users)]
|
||||||
|
|
||||||
|
print(f" Starting {users} concurrent connections for {duration}s...")
|
||||||
|
start = time.time()
|
||||||
|
|
||||||
|
tasks = [ws_client(i, duration, stats[i], ws_url) for i in range(users)]
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
total_time = time.time() - start
|
||||||
|
|
||||||
|
# Aggregate results
|
||||||
|
connected = sum(1 for s in stats if s.connected)
|
||||||
|
total_sent = sum(s.messages_sent for s in stats)
|
||||||
|
total_received = sum(s.messages_received for s in stats)
|
||||||
|
total_errors = sum(s.errors for s in stats)
|
||||||
|
disconnected = sum(1 for s in stats if s.disconnected)
|
||||||
|
|
||||||
|
all_latencies = []
|
||||||
|
for s in stats:
|
||||||
|
all_latencies.extend(s.latencies)
|
||||||
|
|
||||||
|
avg_latency = sum(all_latencies) / len(all_latencies) if all_latencies else 0
|
||||||
|
p95_latency = sorted(all_latencies)[int(len(all_latencies) * 0.95)] if all_latencies else 0
|
||||||
|
p99_latency = sorted(all_latencies)[int(len(all_latencies) * 0.99)] if all_latencies else 0
|
||||||
|
|
||||||
|
avg_connect_time = sum(s.connect_time_ms for s in stats if s.connected) / connected if connected else 0
|
||||||
|
|
||||||
|
return {
|
||||||
|
"users": users,
|
||||||
|
"duration_seconds": round(total_time, 1),
|
||||||
|
"connected": connected,
|
||||||
|
"connect_rate": round(connected / users * 100, 1),
|
||||||
|
"messages_sent": total_sent,
|
||||||
|
"messages_received": total_received,
|
||||||
|
"throughput_msg_per_sec": round(total_sent / total_time, 1) if total_time > 0 else 0,
|
||||||
|
"avg_latency_ms": round(avg_latency, 1),
|
||||||
|
"p95_latency_ms": round(p95_latency, 1),
|
||||||
|
"p99_latency_ms": round(p99_latency, 1),
|
||||||
|
"avg_connect_time_ms": round(avg_connect_time, 1),
|
||||||
|
"errors": total_errors,
|
||||||
|
"disconnected": disconnected,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def print_report(result: dict):
|
||||||
|
"""Print load test report."""
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f" WEBSOCKET LOAD TEST REPORT")
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
print(f" Connections: {result['connected']}/{result['users']} ({result['connect_rate']}%)")
|
||||||
|
print(f" Duration: {result['duration_seconds']}s")
|
||||||
|
print(f" Messages sent: {result['messages_sent']}")
|
||||||
|
print(f" Messages recv: {result['messages_received']}")
|
||||||
|
print(f" Throughput: {result['throughput_msg_per_sec']} msg/s")
|
||||||
|
print(f" Avg connect: {result['avg_connect_time_ms']}ms")
|
||||||
|
print()
|
||||||
|
print(f" Latency:")
|
||||||
|
print(f" Avg: {result['avg_latency_ms']}ms")
|
||||||
|
print(f" P95: {result['p95_latency_ms']}ms")
|
||||||
|
print(f" P99: {result['p99_latency_ms']}ms")
|
||||||
|
print()
|
||||||
|
print(f" Errors: {result['errors']}")
|
||||||
|
print(f" Disconnected: {result['disconnected']}")
|
||||||
|
|
||||||
|
# Verdict
|
||||||
|
if result['connect_rate'] >= 95 and result['errors'] == 0:
|
||||||
|
print(f"\n ✅ PASS")
|
||||||
|
elif result['connect_rate'] >= 80:
|
||||||
|
print(f"\n ⚠️ DEGRADED")
|
||||||
|
else:
|
||||||
|
print(f"\n ❌ FAIL")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="WebSocket Load Test")
|
||||||
|
parser.add_argument("--users", type=int, default=50, help="Concurrent users")
|
||||||
|
parser.add_argument("--duration", type=int, default=30, help="Test duration in seconds")
|
||||||
|
parser.add_argument("--json", action="store_true", help="JSON output")
|
||||||
|
parser.add_argument("--url", default=WS_URL, help="WebSocket URL")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
ws_url = args.url
|
||||||
|
|
||||||
|
print(f"\nWebSocket Load Test — {args.users} users, {args.duration}s\n")
|
||||||
|
|
||||||
|
result = asyncio.run(run_load_test(args.users, args.duration, ws_url))
|
||||||
|
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(result, indent=2))
|
||||||
|
else:
|
||||||
|
print_report(result)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user