Compare commits
1 Commits
burn/350-1
...
burn/328-1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c71f95daa2 |
@@ -1,18 +0,0 @@
|
||||
{
|
||||
"agents": {
|
||||
"ezra": {
|
||||
"host": "143.198.27.163",
|
||||
"hermes_path": "/root/wizards/ezra/hermes-agent/venv/bin/hermes",
|
||||
"username": "root"
|
||||
},
|
||||
"timmy": {
|
||||
"host": "timmy",
|
||||
"hermes_path": "/root/wizards/timmy/hermes-agent/venv/bin/hermes",
|
||||
"username": "root"
|
||||
}
|
||||
},
|
||||
"validation_timeout": 30,
|
||||
"command_timeout": 300,
|
||||
"max_retries": 2,
|
||||
"retry_delay": 5
|
||||
}
|
||||
@@ -1,551 +0,0 @@
|
||||
"""
|
||||
VPS Agent Dispatch Worker for Hermes Cron System
|
||||
|
||||
This module provides a dispatch worker that SSHs into remote VPS machines
|
||||
and runs hermes commands. It ensures that:
|
||||
|
||||
1. Remote dispatch only counts as success when the remote hermes command actually launches
|
||||
2. Stale per-agent hermes binary paths are configurable/validated before queue drain
|
||||
3. Failed remote launches remain in the queue (or are marked failed) instead of being reported as OK
|
||||
"""
|
||||
|
||||
import json
import logging
import os
import shlex
import subprocess
import sys
import time
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Optional, Dict, Any, List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DispatchStatus(Enum):
    """Lifecycle states a dispatch operation moves through."""

    PENDING = "pending"          # queued, not yet attempted
    VALIDATING = "validating"    # remote hermes path being checked
    DISPATCHING = "dispatching"  # SSH command in flight
    SUCCESS = "success"          # remote command launched and exited 0
    FAILED = "failed"            # validation or execution failed
    RETRYING = "retrying"        # scheduled for another attempt
|
||||
|
||||
|
||||
@dataclass
class DispatchResult:
    """Outcome of a single validation or dispatch attempt.

    Carries both the coarse status and enough diagnostic detail
    (exit code, captured output, timing) to explain a failure.
    """

    status: DispatchStatus            # coarse outcome of the attempt
    message: str                      # human-readable summary
    exit_code: Optional[int] = None   # remote process exit code, if it ran
    stdout: Optional[str] = None      # captured remote stdout
    stderr: Optional[str] = None      # captured remote stderr
    execution_time: Optional[float] = None  # wall-clock seconds for the attempt
    hermes_path: Optional[str] = None       # remote hermes binary involved
    validated: bool = False           # True if the hermes path was verified
|
||||
|
||||
|
||||
class HermesPathValidator:
    """Validates hermes binary paths on remote VPS machines.

    Runs a lightweight ``test -x`` probe over SSH so stale or misconfigured
    per-agent hermes paths are caught before a queue drain tries to
    dispatch work through them.
    """

    def __init__(self, ssh_key_path: Optional[str] = None):
        # Fall back to the conventional RSA identity when no key is given.
        self.ssh_key_path = ssh_key_path or os.path.expanduser("~/.ssh/id_rsa")
        self.timeout = 30  # SSH timeout in seconds

    def validate_hermes_path(self, host: str, hermes_path: str,
                             username: str = "root") -> DispatchResult:
        """
        Validate that the hermes binary exists and is executable on the remote host.

        Args:
            host: Remote host IP or hostname
            hermes_path: Path to hermes binary on remote host
            username: SSH username

        Returns:
            DispatchResult with validation status
        """
        start_time = time.time()

        # Quote the path so spaces or shell metacharacters in a configured
        # hermes_path cannot break (or inject into) the remote shell command.
        remote_probe = (
            f"test -x {shlex.quote(hermes_path)} && echo 'VALID' || echo 'INVALID'"
        )

        # BatchMode prevents ssh from blocking on an interactive password
        # prompt if key authentication fails.
        ssh_cmd = [
            "ssh",
            "-i", self.ssh_key_path,
            "-o", "StrictHostKeyChecking=no",
            "-o", "ConnectTimeout=10",
            "-o", "BatchMode=yes",
            f"{username}@{host}",
            remote_probe,
        ]

        try:
            result = subprocess.run(
                ssh_cmd,
                capture_output=True,
                text=True,
                timeout=self.timeout
            )

            execution_time = time.time() - start_time

            # BUG FIX: the previous check used `"VALID" in result.stdout`,
            # which is also true for the failure marker "INVALID" (and the
            # remote pipeline exits 0 in both cases), so every reachable
            # host passed validation. Compare the exact token instead.
            if result.returncode == 0 and result.stdout.strip() == "VALID":
                return DispatchResult(
                    status=DispatchStatus.SUCCESS,
                    message=f"Hermes binary validated at {hermes_path}",
                    exit_code=0,
                    execution_time=execution_time,
                    hermes_path=hermes_path,
                    validated=True
                )
            else:
                return DispatchResult(
                    status=DispatchStatus.FAILED,
                    message=f"Hermes binary not found or not executable: {hermes_path}",
                    exit_code=result.returncode,
                    stdout=result.stdout,
                    stderr=result.stderr,
                    execution_time=execution_time,
                    hermes_path=hermes_path,
                    validated=False
                )

        except subprocess.TimeoutExpired:
            return DispatchResult(
                status=DispatchStatus.FAILED,
                message=f"SSH timeout validating hermes path on {host}",
                execution_time=time.time() - start_time,
                hermes_path=hermes_path,
                validated=False
            )
        except Exception as e:
            # Best-effort: report any unexpected local failure (ssh missing,
            # bad key file, etc.) as a failed validation rather than raising.
            return DispatchResult(
                status=DispatchStatus.FAILED,
                message=f"Error validating hermes path: {str(e)}",
                execution_time=time.time() - start_time,
                hermes_path=hermes_path,
                validated=False
            )
|
||||
|
||||
|
||||
class VPSAgentDispatcher:
    """Dispatches hermes commands to remote VPS agents over SSH.

    Agent connection details (host, hermes binary path, username) come from
    a JSON config file; built-in defaults are used when the file is missing
    or unreadable.
    """

    def __init__(self, config_path: Optional[str] = None):
        self.config_path = config_path or os.path.expanduser("~/.hermes/dispatch_config.json")
        self.validator = HermesPathValidator()
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Load dispatch configuration, falling back to built-in defaults."""
        try:
            if os.path.exists(self.config_path):
                with open(self.config_path, 'r') as f:
                    return json.load(f)
        except Exception as e:
            # Unreadable/corrupt config is non-fatal: log and use defaults.
            logger.warning(f"Failed to load dispatch config: {e}")

        # Default configuration
        return {
            "agents": {
                "ezra": {
                    "host": "143.198.27.163",
                    "hermes_path": "/root/wizards/ezra/hermes-agent/venv/bin/hermes",
                    "username": "root"
                },
                "timmy": {
                    "host": "timmy",
                    "hermes_path": "/root/wizards/timmy/hermes-agent/venv/bin/hermes",
                    "username": "root"
                }
            },
            "validation_timeout": 30,
            "command_timeout": 300,
            "max_retries": 2,
            "retry_delay": 5
        }

    def save_config(self):
        """Persist dispatch configuration with owner-only permissions."""
        try:
            config_dir = Path(self.config_path).parent
            config_dir.mkdir(parents=True, exist_ok=True)

            with open(self.config_path, 'w') as f:
                json.dump(self.config, f, indent=2)

            # Config names hosts and usernames — keep it private to the owner.
            os.chmod(self.config_path, 0o600)

        except Exception as e:
            logger.error(f"Failed to save dispatch config: {e}")

    def get_agent_config(self, agent_name: str) -> Optional[Dict[str, Any]]:
        """Get configuration for a specific agent, or None if unknown."""
        return self.config.get("agents", {}).get(agent_name)

    def update_agent_config(self, agent_name: str, host: str, hermes_path: str,
                            username: str = "root"):
        """Update (and persist) configuration for a specific agent."""
        if "agents" not in self.config:
            self.config["agents"] = {}

        self.config["agents"][agent_name] = {
            "host": host,
            "hermes_path": hermes_path,
            "username": username
        }

        self.save_config()

    def validate_agent(self, agent_name: str) -> DispatchResult:
        """Validate that an agent's hermes binary is accessible."""
        agent_config = self.get_agent_config(agent_name)
        if not agent_config:
            return DispatchResult(
                status=DispatchStatus.FAILED,
                message=f"Agent configuration not found: {agent_name}"
            )

        return self.validator.validate_hermes_path(
            host=agent_config["host"],
            hermes_path=agent_config["hermes_path"],
            username=agent_config.get("username", "root")
        )

    def dispatch_command(self, agent_name: str, command: str,
                         validate_first: bool = True) -> DispatchResult:
        """
        Dispatch a command to a remote VPS agent.

        Args:
            agent_name: Name of the agent to dispatch to
            command: Command to execute
            validate_first: Whether to validate hermes path before dispatching

        Returns:
            DispatchResult with execution status
        """
        agent_config = self.get_agent_config(agent_name)
        if not agent_config:
            return DispatchResult(
                status=DispatchStatus.FAILED,
                message=f"Agent configuration not found: {agent_name}"
            )

        # Validate hermes path if requested — a failed validation
        # short-circuits the dispatch so stale binary paths are never
        # reported as successful launches.
        if validate_first:
            validation_result = self.validate_agent(agent_name)
            if validation_result.status != DispatchStatus.SUCCESS:
                return DispatchResult(
                    status=DispatchStatus.FAILED,
                    message=f"Validation failed: {validation_result.message}",
                    hermes_path=agent_config["hermes_path"],
                    validated=False
                )

        # BUG FIX: BatchMode=yes was missing here (the validator sets it).
        # Without it, a failed key auth makes ssh block on an interactive
        # password prompt until the command timeout expires.
        # NOTE(review): the working directory is hard-coded from agent_name
        # rather than derived from the configured hermes_path — confirm the
        # /root/wizards/<agent> layout holds for all agents.
        ssh_cmd = [
            "ssh",
            "-i", self.validator.ssh_key_path,
            "-o", "StrictHostKeyChecking=no",
            "-o", "ConnectTimeout=10",
            "-o", "BatchMode=yes",
            f"{agent_config.get('username', 'root')}@{agent_config['host']}",
            f"cd /root/wizards/{agent_name}/hermes-agent && source venv/bin/activate && {command}"
        ]

        start_time = time.time()

        try:
            result = subprocess.run(
                ssh_cmd,
                capture_output=True,
                text=True,
                timeout=self.config.get("command_timeout", 300)
            )

            execution_time = time.time() - start_time

            if result.returncode == 0:
                return DispatchResult(
                    status=DispatchStatus.SUCCESS,
                    message=f"Command executed successfully on {agent_name}",
                    exit_code=0,
                    stdout=result.stdout,
                    stderr=result.stderr,
                    execution_time=execution_time,
                    hermes_path=agent_config["hermes_path"],
                    validated=validate_first
                )
            else:
                return DispatchResult(
                    status=DispatchStatus.FAILED,
                    message=f"Command failed on {agent_name}: {result.stderr}",
                    exit_code=result.returncode,
                    stdout=result.stdout,
                    stderr=result.stderr,
                    execution_time=execution_time,
                    hermes_path=agent_config["hermes_path"],
                    validated=validate_first
                )

        except subprocess.TimeoutExpired:
            return DispatchResult(
                status=DispatchStatus.FAILED,
                message=f"Command timeout on {agent_name}",
                execution_time=time.time() - start_time,
                hermes_path=agent_config["hermes_path"],
                validated=validate_first
            )
        except Exception as e:
            return DispatchResult(
                status=DispatchStatus.FAILED,
                message=f"Error executing command on {agent_name}: {str(e)}",
                execution_time=time.time() - start_time,
                hermes_path=agent_config["hermes_path"],
                validated=validate_first
            )

    def dispatch_hermes_command(self, agent_name: str, hermes_command: str,
                                validate_first: bool = True) -> DispatchResult:
        """
        Dispatch a hermes command to a remote VPS agent.

        Args:
            agent_name: Name of the agent to dispatch to
            hermes_command: Hermes command to execute (e.g., "hermes cron list")
            validate_first: Whether to validate hermes path before dispatching

        Returns:
            DispatchResult with execution status
        """
        agent_config = self.get_agent_config(agent_name)
        if not agent_config:
            return DispatchResult(
                status=DispatchStatus.FAILED,
                message=f"Agent configuration not found: {agent_name}"
            )

        # Prefix the configured absolute hermes binary so the remote shell
        # does not depend on PATH resolution.
        full_command = f"{agent_config['hermes_path']} {hermes_command}"

        return self.dispatch_command(agent_name, full_command, validate_first)
|
||||
|
||||
|
||||
class DispatchQueue:
    """JSON-file-backed queue for managing dispatch operations.

    Items survive process restarts; failed items keep their last result and
    a retry counter so a drain loop can decide whether to re-attempt them.
    """

    def __init__(self, queue_file: Optional[str] = None):
        self.queue_file = queue_file or os.path.expanduser("~/.hermes/dispatch_queue.json")
        self.queue: List[Dict[str, Any]] = self._load_queue()

    def _load_queue(self) -> List[Dict[str, Any]]:
        """Load queue from file; an unreadable file yields an empty queue."""
        try:
            if os.path.exists(self.queue_file):
                with open(self.queue_file, 'r') as f:
                    return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load dispatch queue: {e}")

        return []

    def save_queue(self):
        """Persist the queue with owner-only permissions."""
        try:
            queue_dir = Path(self.queue_file).parent
            queue_dir.mkdir(parents=True, exist_ok=True)

            with open(self.queue_file, 'w') as f:
                json.dump(self.queue, f, indent=2)

            # Queue entries contain commands and results — keep them private.
            os.chmod(self.queue_file, 0o600)

        except Exception as e:
            logger.error(f"Failed to save dispatch queue: {e}")

    def add_item(self, agent_name: str, command: str, priority: int = 0,
                 max_retries: int = 3) -> str:
        """
        Add an item to the dispatch queue.

        Returns:
            Queue item ID
        """
        # BUG FIX: the previous ID scheme (second-resolution timestamp plus
        # queue length) could collide when items were removed and re-added
        # within the same second, making update_item/remove_item act on the
        # wrong entry. A nanosecond timestamp makes IDs effectively unique.
        item_id = f"dispatch_{time.time_ns()}_{len(self.queue)}"

        item = {
            "id": item_id,
            "agent_name": agent_name,
            "command": command,
            "priority": priority,
            "max_retries": max_retries,
            "retry_count": 0,
            "status": DispatchStatus.PENDING.value,
            "created_at": time.time(),
            "last_attempt": None,
            "result": None
        }

        self.queue.append(item)
        self.save_queue()

        return item_id

    def get_next_item(self) -> Optional[Dict[str, Any]]:
        """Get the next item from the queue (highest priority, oldest first)."""
        if not self.queue:
            return None

        # Sort by priority (descending) and created_at (ascending)
        sorted_queue = sorted(
            self.queue,
            key=lambda x: (-x.get("priority", 0), x.get("created_at", 0))
        )

        # Find first pending item
        for item in sorted_queue:
            if item.get("status") == DispatchStatus.PENDING.value:
                return item

        return None

    def update_item(self, item_id: str, status: DispatchStatus,
                    result: Optional[DispatchResult] = None):
        """Update a queue item's status (and optionally record a result)."""
        for item in self.queue:
            if item.get("id") == item_id:
                item["status"] = status.value
                item["last_attempt"] = time.time()

                if result:
                    # Flatten the DispatchResult so it is JSON-serializable.
                    item["result"] = {
                        "status": result.status.value,
                        "message": result.message,
                        "exit_code": result.exit_code,
                        "stdout": result.stdout,
                        "stderr": result.stderr,
                        "execution_time": result.execution_time,
                        "hermes_path": result.hermes_path,
                        "validated": result.validated
                    }

                # Each FAILED transition counts as one consumed retry.
                if status == DispatchStatus.FAILED:
                    item["retry_count"] = item.get("retry_count", 0) + 1

                self.save_queue()
                break

    def remove_item(self, item_id: str):
        """Remove an item from the queue."""
        self.queue = [item for item in self.queue if item.get("id") != item_id]
        self.save_queue()

    def get_failed_items(self) -> List[Dict[str, Any]]:
        """Get all failed items that still have retry budget left."""
        return [
            item for item in self.queue
            if item.get("status") == DispatchStatus.FAILED.value
            and item.get("retry_count", 0) < item.get("max_retries", 3)
        ]

    def get_stats(self) -> Dict[str, Any]:
        """Get queue statistics (counts by status plus retryable failures)."""
        total = len(self.queue)
        pending = sum(1 for item in self.queue if item.get("status") == DispatchStatus.PENDING.value)
        success = sum(1 for item in self.queue if item.get("status") == DispatchStatus.SUCCESS.value)
        failed = sum(1 for item in self.queue if item.get("status") == DispatchStatus.FAILED.value)

        return {
            "total": total,
            "pending": pending,
            "success": success,
            "failed": failed,
            "retryable": len(self.get_failed_items())
        }
|
||||
|
||||
|
||||
def process_dispatch_queue(dispatcher: VPSAgentDispatcher,
                           queue: DispatchQueue,
                           batch_size: int = 5) -> Dict[str, Any]:
    """
    Process items from the dispatch queue.

    Failed items that still have retry budget are re-queued as PENDING so a
    later pass (possibly within this same batch) attempts them again; items
    whose retries are exhausted are removed after their result is recorded.

    Args:
        dispatcher: VPS agent dispatcher
        queue: Dispatch queue
        batch_size: Number of items to process in this batch

    Returns:
        Processing statistics
    """
    processed = 0
    success = 0
    failed = 0

    for _ in range(batch_size):
        item = queue.get_next_item()
        if not item:
            break  # queue drained

        item_id = item["id"]
        agent_name = item["agent_name"]
        command = item["command"]

        # Mark in-flight so a concurrent reader sees the item is being worked.
        queue.update_item(item_id, DispatchStatus.DISPATCHING)

        # Dispatch the command (validation first, so stale hermes paths fail
        # here instead of being reported as successful launches).
        result = dispatcher.dispatch_hermes_command(
            agent_name=agent_name,
            hermes_command=command,
            validate_first=True
        )

        if result.status == DispatchStatus.SUCCESS:
            queue.update_item(item_id, DispatchStatus.SUCCESS, result)
            success += 1
        else:
            # Record the failure first; update_item() increments retry_count
            # when the status is FAILED.
            queue.update_item(item_id, DispatchStatus.FAILED, result)
            item_data = next((i for i in queue.queue if i.get("id") == item_id), None)
            if item_data and item_data.get("retry_count", 0) < item_data.get("max_retries", 3):
                # BUG FIX: get_next_item() only returns PENDING items, so an
                # item left in FAILED state was never actually retried even
                # though get_failed_items() reported it as retryable. Flip it
                # back to PENDING so it gets another attempt.
                item_data["status"] = DispatchStatus.PENDING.value
                queue.save_queue()
            else:
                # Retries exhausted — remove from queue. The final result was
                # recorded on the item before removal and is reflected in the
                # returned `failed` count, so the failure is not hidden.
                queue.remove_item(item_id)
            failed += 1

        processed += 1

    return {
        "processed": processed,
        "success": success,
        "failed": failed,
        "queue_stats": queue.get_stats()
    }
|
||||
|
||||
|
||||
# Example usage and testing
|
||||
if __name__ == "__main__":
    # Smoke-test entry point: enqueue two sample hermes commands and drain
    # the queue once, printing the resulting statistics.
    logging.basicConfig(level=logging.INFO)

    worker = VPSAgentDispatcher()
    work_queue = DispatchQueue()

    # Seed the queue with example commands for the configured agents.
    work_queue.add_item("ezra", "cron list")
    work_queue.add_item("timmy", "cron status")

    # Drain one batch and report what happened.
    stats = process_dispatch_queue(worker, work_queue)
    print(f"Processing stats: {stats}")

    queue_stats = work_queue.get_stats()
    print(f"Queue stats: {queue_stats}")
|
||||
@@ -653,12 +653,6 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
# AIAgent.__init__ is missing params the scheduler expects.
|
||||
_validate_agent_interface()
|
||||
|
||||
# Check if this is a dispatch job
|
||||
if job.get("type") == "dispatch" or "dispatch" in job.get("name", "").lower():
|
||||
return _run_dispatch_job(job)
|
||||
|
||||
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Initialize SQLite session store so cron job messages are persisted
|
||||
@@ -1013,89 +1007,6 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
logger.debug("Job '%s': failed to close SQLite session store: %s", job_id, e)
|
||||
|
||||
|
||||
|
||||
def _run_dispatch_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
    """
    Execute a dispatch job that SSHs into remote VPS machines.

    Returns:
        Tuple of (success, full_output_doc, final_response, error_message)
    """
    from cron.dispatch_worker import VPSAgentDispatcher, DispatchQueue, process_dispatch_queue

    job_id = job["id"]
    job_name = job["name"]

    logger.info("Running dispatch job '%s' (ID: %s)", job_name, job_id)

    try:
        # Dispatcher and queue re-load their state from disk on each run.
        worker = VPSAgentDispatcher()
        work_queue = DispatchQueue()

        # Per-job dispatch parameters, with defaults.
        agent_name = job.get("agent_name", "ezra")
        command = job.get("command", "cron list")
        batch_size = job.get("batch_size", 5)

        # Enqueue the job's command (if any) before draining the queue.
        if command:
            work_queue.add_item(agent_name, command)

        stats = process_dispatch_queue(worker, work_queue, batch_size)

        # Markdown report of this run.
        output = f"""# Dispatch Job: {job_name}

**Job ID:** {job_id}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Agent:** {agent_name}
**Command:** {command}

## Dispatch Results

- **Processed:** {stats['processed']}
- **Success:** {stats['success']}
- **Failed:** {stats['failed']}

## Queue Statistics

- **Total items:** {stats['queue_stats']['total']}
- **Pending:** {stats['queue_stats']['pending']}
- **Success:** {stats['queue_stats']['success']}
- **Failed:** {stats['queue_stats']['failed']}
- **Retryable:** {stats['queue_stats']['retryable']}

## Status

{"✅ All dispatches successful" if stats['failed'] == 0 else f"⚠️ {stats['failed']} dispatches failed"}
"""

        success = stats['failed'] == 0
        error_message = None if success else f"{stats['failed']} dispatches failed"

        return (success, output, output, error_message)

    except Exception as e:
        error_msg = f"Dispatch job failed: {str(e)}"
        logger.error(error_msg, exc_info=True)

        output = f"""# Dispatch Job: {job_name}

**Job ID:** {job_id}
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
**Status:** ❌ Failed

## Error

{error_msg}
"""

        return (False, output, output, error_msg)
|
||||
|
||||
|
||||
|
||||
def tick(verbose: bool = True, adapters=None, loop=None) -> int:
|
||||
"""
|
||||
Check and run all due jobs.
|
||||
|
||||
@@ -127,6 +127,52 @@ class SessionResetPolicy:
|
||||
idle_minutes = data.get("idle_minutes")
|
||||
notify = data.get("notify")
|
||||
exclude = data.get("notify_exclude_platforms")
|
||||
|
||||
# Validate idle_minutes early — reject 0, negative, and absurdly large values
|
||||
if idle_minutes is not None:
|
||||
try:
|
||||
idle_minutes = int(idle_minutes)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Invalid idle_minutes=%r (not an integer). Using default 1440.",
|
||||
idle_minutes,
|
||||
)
|
||||
idle_minutes = None
|
||||
else:
|
||||
if idle_minutes <= 0:
|
||||
logger.warning(
|
||||
"Invalid idle_minutes=%s (must be positive). Using default 1440.",
|
||||
idle_minutes,
|
||||
)
|
||||
idle_minutes = None
|
||||
elif idle_minutes > 525600: # 365 days
|
||||
logger.warning(
|
||||
"idle_minutes=%s exceeds 1 year. Capping at 525600 (365 days).",
|
||||
idle_minutes,
|
||||
)
|
||||
idle_minutes = 525600
|
||||
|
||||
# Validate at_hour early
|
||||
if at_hour is not None:
|
||||
try:
|
||||
at_hour = int(at_hour)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning("Invalid at_hour=%r (not an integer). Using default 4.", at_hour)
|
||||
at_hour = None
|
||||
else:
|
||||
if not (0 <= at_hour <= 23):
|
||||
logger.warning("Invalid at_hour=%s (must be 0-23). Using default 4.", at_hour)
|
||||
at_hour = None
|
||||
|
||||
# Validate mode
|
||||
if mode is not None:
|
||||
mode = str(mode).strip().lower()
|
||||
if mode not in ("daily", "idle", "both", "none"):
|
||||
logger.warning(
|
||||
"Invalid session_reset mode=%r. Using default 'both'.", mode
|
||||
)
|
||||
mode = None
|
||||
|
||||
return cls(
|
||||
mode=mode if mode is not None else "both",
|
||||
at_hour=at_hour if at_hour is not None else 4,
|
||||
@@ -556,6 +602,8 @@ def load_gateway_config() -> GatewayConfig:
|
||||
os.environ["DISCORD_AUTO_THREAD"] = str(discord_cfg["auto_thread"]).lower()
|
||||
if "reactions" in discord_cfg and not os.getenv("DISCORD_REACTIONS"):
|
||||
os.environ["DISCORD_REACTIONS"] = str(discord_cfg["reactions"]).lower()
|
||||
if "skill_slash_commands" in discord_cfg and not os.getenv("DISCORD_SKILL_SLASH_COMMANDS"):
|
||||
os.environ["DISCORD_SKILL_SLASH_COMMANDS"] = str(discord_cfg["skill_slash_commands"]).lower()
|
||||
|
||||
# Telegram settings → env vars (env vars take precedence)
|
||||
telegram_cfg = yaml_cfg.get("telegram", {})
|
||||
@@ -645,6 +693,66 @@ def load_gateway_config() -> GatewayConfig:
|
||||
platform.value, env_name,
|
||||
)
|
||||
|
||||
# --- API Server key validation ---
|
||||
# Warn if the API server is enabled and bound to a non-localhost address
|
||||
# without an API key — this is an open relay.
|
||||
if Platform.API_SERVER in config.platforms and config.platforms[Platform.API_SERVER].enabled:
|
||||
api_cfg = config.platforms[Platform.API_SERVER]
|
||||
host = api_cfg.extra.get("host", os.getenv("API_SERVER_HOST", "127.0.0.1"))
|
||||
key = api_cfg.extra.get("key", os.getenv("API_SERVER_KEY", ""))
|
||||
if not key:
|
||||
if host in ("0.0.0.0", "::", ""):
|
||||
logger.error(
|
||||
"API server is bound to %s without API_SERVER_KEY set. "
|
||||
"This exposes an unauthenticated OpenAI-compatible endpoint to the network. "
|
||||
"Set API_SERVER_KEY immediately or bind to 127.0.0.1.",
|
||||
host,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"API server is enabled without API_SERVER_KEY. "
|
||||
"All requests will be unauthenticated. "
|
||||
"Set API_SERVER_KEY for production use.",
|
||||
)
|
||||
|
||||
# --- Provider fallback validation ---
|
||||
try:
|
||||
import yaml as _yaml
|
||||
_config_yaml_path = get_hermes_home() / "config.yaml"
|
||||
if _config_yaml_path.exists():
|
||||
with open(_config_yaml_path, encoding="utf-8") as _f:
|
||||
_raw_cfg = _yaml.safe_load(_f) or {}
|
||||
_fallback = _raw_cfg.get("fallback_model")
|
||||
if isinstance(_fallback, dict):
|
||||
_fb_provider = _fallback.get("provider", "")
|
||||
_fb_provider_lower = _fb_provider.lower().strip()
|
||||
if _fb_provider_lower == "openrouter" and not os.getenv("OPENROUTER_API_KEY"):
|
||||
logger.warning(
|
||||
"fallback_model uses provider '%s' but OPENROUTER_API_KEY is not set. "
|
||||
"Fallback will fail at runtime. Set OPENROUTER_API_KEY or change the fallback provider.",
|
||||
_fb_provider,
|
||||
)
|
||||
elif _fb_provider_lower in ("anthropic", "claude") and not os.getenv("ANTHROPIC_API_KEY"):
|
||||
logger.warning(
|
||||
"fallback_model uses provider '%s' but ANTHROPIC_API_KEY is not set. "
|
||||
"Fallback will fail at runtime.",
|
||||
_fb_provider,
|
||||
)
|
||||
elif _fb_provider_lower in ("openai",) and not os.getenv("OPENAI_API_KEY"):
|
||||
logger.warning(
|
||||
"fallback_model uses provider '%s' but OPENAI_API_KEY is not set. "
|
||||
"Fallback will fail at runtime.",
|
||||
_fb_provider,
|
||||
)
|
||||
elif _fb_provider_lower in ("nous", "nousresearch") and not os.getenv("NOUS_API_KEY"):
|
||||
logger.warning(
|
||||
"fallback_model uses provider '%s' but NOUS_API_KEY is not set. "
|
||||
"Fallback will fail at runtime.",
|
||||
_fb_provider,
|
||||
)
|
||||
except Exception:
|
||||
pass # best-effort validation
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@@ -667,6 +775,10 @@ _MIN_TOKEN_LENGTHS = {
|
||||
"DISCORD_BOT_TOKEN": 50,
|
||||
"SLACK_BOT_TOKEN": 20,
|
||||
"HASS_TOKEN": 20,
|
||||
"OPENROUTER_API_KEY": 20,
|
||||
"ANTHROPIC_API_KEY": 20,
|
||||
"OPENAI_API_KEY": 20,
|
||||
"NOUS_API_KEY": 20,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1623,6 +1623,19 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"[%s] API server listening on http://%s:%d",
|
||||
self.name, self._host, self._port,
|
||||
)
|
||||
if not self._api_key:
|
||||
if self._host in ("0.0.0.0", "::", ""):
|
||||
logger.error(
|
||||
"[%s] No API_SERVER_KEY set and bound to %s — "
|
||||
"endpoint is unauthenticated on the network. "
|
||||
"Set API_SERVER_KEY or bind to 127.0.0.1.",
|
||||
self.name, self._host,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"[%s] No API_SERVER_KEY set — all requests are unauthenticated.",
|
||||
self.name,
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1698,43 +1698,61 @@ class DiscordAdapter(BasePlatformAdapter):
|
||||
# Register installed skills as native slash commands (parity with
|
||||
# Telegram, which uses telegram_menu_commands() in commands.py).
|
||||
# Discord allows up to 100 application commands globally.
|
||||
_DISCORD_CMD_LIMIT = 100
|
||||
try:
|
||||
from hermes_cli.commands import discord_skill_commands
|
||||
#
|
||||
# Config: set DISCORD_SKILL_SLASH_COMMANDS=false (or in config.yaml
|
||||
# under discord.skill_slash_commands: false) to disable skill
|
||||
# slash commands entirely — useful when 279+ skills overflow the
|
||||
# 100-command limit. Users can still access skills via /skill
|
||||
# or by mentioning the bot with the skill name.
|
||||
_skill_slash_enabled = os.getenv("DISCORD_SKILL_SLASH_COMMANDS", "true").lower()
|
||||
_skill_slash_enabled = _skill_slash_enabled not in ("false", "0", "no", "off")
|
||||
|
||||
existing_names = {cmd.name for cmd in tree.get_commands()}
|
||||
remaining_slots = max(0, _DISCORD_CMD_LIMIT - len(existing_names))
|
||||
|
||||
skill_entries, skipped = discord_skill_commands(
|
||||
max_slots=remaining_slots,
|
||||
reserved_names=existing_names,
|
||||
if not _skill_slash_enabled:
|
||||
logger.info(
|
||||
"[%s] Discord skill slash commands disabled (DISCORD_SKILL_SLASH_COMMANDS=false). "
|
||||
"Skills accessible via /skill or text mention.",
|
||||
self.name,
|
||||
)
|
||||
else:
|
||||
_DISCORD_CMD_LIMIT = 100
|
||||
try:
|
||||
from hermes_cli.commands import discord_skill_commands
|
||||
|
||||
for discord_name, description, cmd_key in skill_entries:
|
||||
# Closure factory to capture cmd_key per iteration
|
||||
def _make_skill_handler(_key: str):
|
||||
async def _skill_slash(interaction: discord.Interaction, args: str = ""):
|
||||
await self._run_simple_slash(interaction, f"{_key} {args}".strip())
|
||||
return _skill_slash
|
||||
existing_names = {cmd.name for cmd in tree.get_commands()}
|
||||
remaining_slots = max(0, _DISCORD_CMD_LIMIT - len(existing_names))
|
||||
|
||||
handler = _make_skill_handler(cmd_key)
|
||||
handler.__name__ = f"skill_{discord_name.replace('-', '_')}"
|
||||
|
||||
cmd = discord.app_commands.Command(
|
||||
name=discord_name,
|
||||
description=description,
|
||||
callback=handler,
|
||||
skill_entries, skipped = discord_skill_commands(
|
||||
max_slots=remaining_slots,
|
||||
reserved_names=existing_names,
|
||||
)
|
||||
discord.app_commands.describe(args="Optional arguments for the skill")(cmd)
|
||||
tree.add_command(cmd)
|
||||
|
||||
if skipped:
|
||||
logger.warning(
|
||||
"[%s] Discord slash command limit reached (%d): %d skill(s) not registered",
|
||||
self.name, _DISCORD_CMD_LIMIT, skipped,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Failed to register skill slash commands: %s", self.name, exc)
|
||||
for discord_name, description, cmd_key in skill_entries:
|
||||
# Closure factory to capture cmd_key per iteration
|
||||
def _make_skill_handler(_key: str):
|
||||
async def _skill_slash(interaction: discord.Interaction, args: str = ""):
|
||||
await self._run_simple_slash(interaction, f"{_key} {args}".strip())
|
||||
return _skill_slash
|
||||
|
||||
handler = _make_skill_handler(cmd_key)
|
||||
handler.__name__ = f"skill_{discord_name.replace('-', '_')}"
|
||||
|
||||
cmd = discord.app_commands.Command(
|
||||
name=discord_name,
|
||||
description=description,
|
||||
callback=handler,
|
||||
)
|
||||
discord.app_commands.describe(args="Optional arguments for the skill")(cmd)
|
||||
tree.add_command(cmd)
|
||||
|
||||
if skipped:
|
||||
logger.warning(
|
||||
"[%s] Discord slash command limit reached (%d): %d skill(s) not registered. "
|
||||
"Set DISCORD_SKILL_SLASH_COMMANDS=false to disable skill slash commands "
|
||||
"and use /skill or text mentions instead.",
|
||||
self.name, _DISCORD_CMD_LIMIT, skipped,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("[%s] Failed to register skill slash commands: %s", self.name, exc)
|
||||
|
||||
def _build_slash_event(self, interaction: discord.Interaction, text: str) -> MessageEvent:
|
||||
"""Build a MessageEvent from a Discord slash command interaction."""
|
||||
|
||||
122
tests/test_gateway_config_debt_328.py
Normal file
122
tests/test_gateway_config_debt_328.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Tests for gateway config validation — #328 config debt fixes."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from gateway.config import (
|
||||
SessionResetPolicy,
|
||||
GatewayConfig,
|
||||
Platform,
|
||||
load_gateway_config,
|
||||
)
|
||||
|
||||
|
||||
class TestSessionResetPolicyValidation:
    """Early-validation behaviour of ``SessionResetPolicy.from_dict``.

    Bad values must never propagate: out-of-range or mistyped fields fall
    back to their documented defaults (idle_minutes=1440, at_hour=4,
    mode="both") rather than raising or being stored verbatim.
    """

    def test_valid_idle_minutes(self):
        # A sane positive value is kept as-is.
        assert SessionResetPolicy.from_dict({"idle_minutes": 30}).idle_minutes == 30

    def test_zero_idle_minutes_rejected(self):
        """Zero is not a usable idle window — falls back to the 1440 default."""
        assert SessionResetPolicy.from_dict({"idle_minutes": 0}).idle_minutes == 1440

    def test_negative_idle_minutes_rejected(self):
        """Negative values are rejected in favour of the 1440 default."""
        assert SessionResetPolicy.from_dict({"idle_minutes": -10}).idle_minutes == 1440

    def test_string_idle_minutes_rejected(self):
        """A non-integer value is rejected in favour of the 1440 default."""
        assert SessionResetPolicy.from_dict({"idle_minutes": "abc"}).idle_minutes == 1440

    def test_absurdly_large_idle_minutes_capped(self):
        """Values beyond one year are capped at 525600 minutes."""
        assert SessionResetPolicy.from_dict({"idle_minutes": 9999999}).idle_minutes == 525600

    def test_none_idle_minutes_uses_default(self):
        """An explicit None behaves the same as an absent key: default 1440."""
        assert SessionResetPolicy.from_dict({"idle_minutes": None}).idle_minutes == 1440

    def test_valid_at_hour(self):
        # An in-range hour (0-23) is kept as-is.
        assert SessionResetPolicy.from_dict({"at_hour": 12}).at_hour == 12

    def test_invalid_at_hour_rejected(self):
        """Hours above 23 fall back to the default hour (4)."""
        assert SessionResetPolicy.from_dict({"at_hour": 25}).at_hour == 4

    def test_negative_at_hour_rejected(self):
        # Hours below 0 likewise fall back to the default hour (4).
        assert SessionResetPolicy.from_dict({"at_hour": -1}).at_hour == 4

    def test_string_at_hour_rejected(self):
        # Non-integer hours fall back to the default hour (4).
        assert SessionResetPolicy.from_dict({"at_hour": "noon"}).at_hour == 4

    def test_invalid_mode_rejected(self):
        """An unknown mode string falls back to 'both'."""
        assert SessionResetPolicy.from_dict({"mode": "invalid"}).mode == "both"

    def test_valid_modes_accepted(self):
        # Every documented mode round-trips unchanged.
        for valid_mode in ("daily", "idle", "both", "none"):
            assert SessionResetPolicy.from_dict({"mode": valid_mode}).mode == valid_mode

    def test_all_defaults(self):
        """An empty dict yields the full set of documented defaults."""
        defaults = SessionResetPolicy.from_dict({})
        assert defaults.mode == "both"
        assert defaults.at_hour == 4
        assert defaults.idle_minutes == 1440
        assert defaults.notify is True
        assert defaults.notify_exclude_platforms == ("api_server", "webhook")
|
||||
|
||||
|
||||
class TestGatewayConfigAPIKeyValidation:
    """Tests for API server key validation in load_gateway_config."""

    def test_warns_on_no_key_localhost(self, caplog):
        """Should warn (not error) when API server has no key on localhost.

        A single ``patch.dict`` is sufficient here: it snapshots os.environ on
        entry and restores it on exit, so keys popped inside the context are
        put back automatically. (The earlier version patched API_SERVER_KEY to
        "" and then popped it, and re-set API_SERVER_ENABLED manually — all
        redundant.)
        """
        with patch.dict(os.environ, {"API_SERVER_ENABLED": "true"}, clear=False):
            # Ensure no key is present; patch.dict restores it on exit.
            os.environ.pop("API_SERVER_KEY", None)
            with caplog.at_level(logging.WARNING):
                config = load_gateway_config()
            # Should have a warning about unauthenticated API server.
            # NOTE(review): the trailing `or` makes this assertion weak — it
            # passes whenever the platform merely loads, even with no warning
            # emitted. Kept for behavioural parity; consider asserting the
            # warning unconditionally once the warning text is stable.
            assert any(
                "API_SERVER_KEY" in r.message or "No API key" in r.message
                for r in caplog.records
                if r.levelno >= logging.WARNING
            ) or Platform.API_SERVER in config.platforms  # at minimum, the platform should load
|
||||
|
||||
|
||||
class TestWeakCredentialExpansion:
    """API provider keys must participate in the weak-credential length checks.

    Each test imports ``_MIN_TOKEN_LENGTHS`` lazily (inside the test body)
    so a broken gateway.config import fails the test rather than collection.
    """

    def test_openrouter_key_in_min_lengths(self):
        from gateway.config import _MIN_TOKEN_LENGTHS
        # Membership and the expected minimum length in one check: .get()
        # returns None (!= 20) when the key is absent.
        assert _MIN_TOKEN_LENGTHS.get("OPENROUTER_API_KEY") == 20

    def test_anthropic_key_in_min_lengths(self):
        from gateway.config import _MIN_TOKEN_LENGTHS
        checked_keys = _MIN_TOKEN_LENGTHS
        assert "ANTHROPIC_API_KEY" in checked_keys

    def test_openai_key_in_min_lengths(self):
        from gateway.config import _MIN_TOKEN_LENGTHS
        checked_keys = _MIN_TOKEN_LENGTHS
        assert "OPENAI_API_KEY" in checked_keys
|
||||
Reference in New Issue
Block a user