Compare commits
10 Commits
fix/memory
...
burn/350-1
| Author | SHA1 | Date | |
|---|---|---|---|
| 54590dd627 | |||
| 378476570d | |||
| 4da1d8be88 | |||
| 1ec02cf061 | |||
|
|
1156875cb5 | ||
| f4c102400e | |||
| 6555ccabc1 | |||
|
|
8c712866c4 | ||
| 8fb59aae64 | |||
|
|
aa6eabb816 |
18
config/dispatch-config.json
Normal file
18
config/dispatch-config.json
Normal file
@@ -0,0 +1,18 @@
|
||||
{
|
||||
"agents": {
|
||||
"ezra": {
|
||||
"host": "143.198.27.163",
|
||||
"hermes_path": "/root/wizards/ezra/hermes-agent/venv/bin/hermes",
|
||||
"username": "root"
|
||||
},
|
||||
"timmy": {
|
||||
"host": "timmy",
|
||||
"hermes_path": "/root/wizards/timmy/hermes-agent/venv/bin/hermes",
|
||||
"username": "root"
|
||||
}
|
||||
},
|
||||
"validation_timeout": 30,
|
||||
"command_timeout": 300,
|
||||
"max_retries": 2,
|
||||
"retry_delay": 5
|
||||
}
|
||||
551
cron/dispatch_worker.py
Normal file
551
cron/dispatch_worker.py
Normal file
@@ -0,0 +1,551 @@
|
||||
"""
|
||||
VPS Agent Dispatch Worker for Hermes Cron System
|
||||
|
||||
This module provides a dispatch worker that SSHs into remote VPS machines
|
||||
and runs hermes commands. It ensures that:
|
||||
|
||||
1. Remote dispatch only counts as success when the remote hermes command actually launches
|
||||
2. Stale per-agent hermes binary paths are configurable/validated before queue drain
|
||||
3. Failed remote launches remain in the queue (or are marked failed) instead of being reported as OK
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DispatchStatus(Enum):
|
||||
"""Status of a dispatch operation."""
|
||||
PENDING = "pending"
|
||||
VALIDATING = "validating"
|
||||
DISPATCHING = "dispatching"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
RETRYING = "retrying"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DispatchResult:
|
||||
"""Result of a dispatch operation."""
|
||||
status: DispatchStatus
|
||||
message: str
|
||||
exit_code: Optional[int] = None
|
||||
stdout: Optional[str] = None
|
||||
stderr: Optional[str] = None
|
||||
execution_time: Optional[float] = None
|
||||
hermes_path: Optional[str] = None
|
||||
validated: bool = False
|
||||
|
||||
|
||||
class HermesPathValidator:
|
||||
"""Validates hermes binary paths on remote VPS machines."""
|
||||
|
||||
def __init__(self, ssh_key_path: Optional[str] = None):
|
||||
self.ssh_key_path = ssh_key_path or os.path.expanduser("~/.ssh/id_rsa")
|
||||
self.timeout = 30 # SSH timeout in seconds
|
||||
|
||||
def validate_hermes_path(self, host: str, hermes_path: str,
|
||||
username: str = "root") -> DispatchResult:
|
||||
"""
|
||||
Validate that the hermes binary exists and is executable on the remote host.
|
||||
|
||||
Args:
|
||||
host: Remote host IP or hostname
|
||||
hermes_path: Path to hermes binary on remote host
|
||||
username: SSH username
|
||||
|
||||
Returns:
|
||||
DispatchResult with validation status
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
# Build SSH command to check hermes binary
|
||||
ssh_cmd = [
|
||||
"ssh",
|
||||
"-i", self.ssh_key_path,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "ConnectTimeout=10",
|
||||
"-o", "BatchMode=yes",
|
||||
f"{username}@{host}",
|
||||
f"test -x {hermes_path} && echo 'VALID' || echo 'INVALID'"
|
||||
]
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
ssh_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=self.timeout
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if result.returncode == 0 and "VALID" in result.stdout:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.SUCCESS,
|
||||
message=f"Hermes binary validated at {hermes_path}",
|
||||
exit_code=0,
|
||||
execution_time=execution_time,
|
||||
hermes_path=hermes_path,
|
||||
validated=True
|
||||
)
|
||||
else:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Hermes binary not found or not executable: {hermes_path}",
|
||||
exit_code=result.returncode,
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
execution_time=execution_time,
|
||||
hermes_path=hermes_path,
|
||||
validated=False
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"SSH timeout validating hermes path on {host}",
|
||||
execution_time=time.time() - start_time,
|
||||
hermes_path=hermes_path,
|
||||
validated=False
|
||||
)
|
||||
except Exception as e:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Error validating hermes path: {str(e)}",
|
||||
execution_time=time.time() - start_time,
|
||||
hermes_path=hermes_path,
|
||||
validated=False
|
||||
)
|
||||
|
||||
|
||||
class VPSAgentDispatcher:
|
||||
"""Dispatches hermes commands to remote VPS agents."""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
self.config_path = config_path or os.path.expanduser("~/.hermes/dispatch_config.json")
|
||||
self.validator = HermesPathValidator()
|
||||
self.config = self._load_config()
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""Load dispatch configuration."""
|
||||
try:
|
||||
if os.path.exists(self.config_path):
|
||||
with open(self.config_path, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load dispatch config: {e}")
|
||||
|
||||
# Default configuration
|
||||
return {
|
||||
"agents": {
|
||||
"ezra": {
|
||||
"host": "143.198.27.163",
|
||||
"hermes_path": "/root/wizards/ezra/hermes-agent/venv/bin/hermes",
|
||||
"username": "root"
|
||||
},
|
||||
"timmy": {
|
||||
"host": "timmy",
|
||||
"hermes_path": "/root/wizards/timmy/hermes-agent/venv/bin/hermes",
|
||||
"username": "root"
|
||||
}
|
||||
},
|
||||
"validation_timeout": 30,
|
||||
"command_timeout": 300,
|
||||
"max_retries": 2,
|
||||
"retry_delay": 5
|
||||
}
|
||||
|
||||
def save_config(self):
|
||||
"""Save dispatch configuration."""
|
||||
try:
|
||||
config_dir = Path(self.config_path).parent
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(self.config_path, 'w') as f:
|
||||
json.dump(self.config, f, indent=2)
|
||||
|
||||
# Set secure permissions
|
||||
os.chmod(self.config_path, 0o600)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save dispatch config: {e}")
|
||||
|
||||
def get_agent_config(self, agent_name: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get configuration for a specific agent."""
|
||||
return self.config.get("agents", {}).get(agent_name)
|
||||
|
||||
def update_agent_config(self, agent_name: str, host: str, hermes_path: str,
|
||||
username: str = "root"):
|
||||
"""Update configuration for a specific agent."""
|
||||
if "agents" not in self.config:
|
||||
self.config["agents"] = {}
|
||||
|
||||
self.config["agents"][agent_name] = {
|
||||
"host": host,
|
||||
"hermes_path": hermes_path,
|
||||
"username": username
|
||||
}
|
||||
|
||||
self.save_config()
|
||||
|
||||
def validate_agent(self, agent_name: str) -> DispatchResult:
|
||||
"""Validate that an agent's hermes binary is accessible."""
|
||||
agent_config = self.get_agent_config(agent_name)
|
||||
if not agent_config:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Agent configuration not found: {agent_name}"
|
||||
)
|
||||
|
||||
return self.validator.validate_hermes_path(
|
||||
host=agent_config["host"],
|
||||
hermes_path=agent_config["hermes_path"],
|
||||
username=agent_config.get("username", "root")
|
||||
)
|
||||
|
||||
def dispatch_command(self, agent_name: str, command: str,
|
||||
validate_first: bool = True) -> DispatchResult:
|
||||
"""
|
||||
Dispatch a command to a remote VPS agent.
|
||||
|
||||
Args:
|
||||
agent_name: Name of the agent to dispatch to
|
||||
command: Command to execute
|
||||
validate_first: Whether to validate hermes path before dispatching
|
||||
|
||||
Returns:
|
||||
DispatchResult with execution status
|
||||
"""
|
||||
agent_config = self.get_agent_config(agent_name)
|
||||
if not agent_config:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Agent configuration not found: {agent_name}"
|
||||
)
|
||||
|
||||
# Validate hermes path if requested
|
||||
if validate_first:
|
||||
validation_result = self.validate_agent(agent_name)
|
||||
if validation_result.status != DispatchStatus.SUCCESS:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Validation failed: {validation_result.message}",
|
||||
hermes_path=agent_config["hermes_path"],
|
||||
validated=False
|
||||
)
|
||||
|
||||
# Build SSH command to execute hermes command
|
||||
ssh_cmd = [
|
||||
"ssh",
|
||||
"-i", self.validator.ssh_key_path,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "ConnectTimeout=10",
|
||||
f"{agent_config.get('username', 'root')}@{agent_config['host']}",
|
||||
f"cd /root/wizards/{agent_name}/hermes-agent && source venv/bin/activate && {command}"
|
||||
]
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
result = subprocess.run(
|
||||
ssh_cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=self.config.get("command_timeout", 300)
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if result.returncode == 0:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.SUCCESS,
|
||||
message=f"Command executed successfully on {agent_name}",
|
||||
exit_code=0,
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
execution_time=execution_time,
|
||||
hermes_path=agent_config["hermes_path"],
|
||||
validated=validate_first
|
||||
)
|
||||
else:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Command failed on {agent_name}: {result.stderr}",
|
||||
exit_code=result.returncode,
|
||||
stdout=result.stdout,
|
||||
stderr=result.stderr,
|
||||
execution_time=execution_time,
|
||||
hermes_path=agent_config["hermes_path"],
|
||||
validated=validate_first
|
||||
)
|
||||
|
||||
except subprocess.TimeoutExpired:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Command timeout on {agent_name}",
|
||||
execution_time=time.time() - start_time,
|
||||
hermes_path=agent_config["hermes_path"],
|
||||
validated=validate_first
|
||||
)
|
||||
except Exception as e:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Error executing command on {agent_name}: {str(e)}",
|
||||
execution_time=time.time() - start_time,
|
||||
hermes_path=agent_config["hermes_path"],
|
||||
validated=validate_first
|
||||
)
|
||||
|
||||
def dispatch_hermes_command(self, agent_name: str, hermes_command: str,
|
||||
validate_first: bool = True) -> DispatchResult:
|
||||
"""
|
||||
Dispatch a hermes command to a remote VPS agent.
|
||||
|
||||
Args:
|
||||
agent_name: Name of the agent to dispatch to
|
||||
hermes_command: Hermes command to execute (e.g., "hermes cron list")
|
||||
validate_first: Whether to validate hermes path before dispatching
|
||||
|
||||
Returns:
|
||||
DispatchResult with execution status
|
||||
"""
|
||||
agent_config = self.get_agent_config(agent_name)
|
||||
if not agent_config:
|
||||
return DispatchResult(
|
||||
status=DispatchStatus.FAILED,
|
||||
message=f"Agent configuration not found: {agent_name}"
|
||||
)
|
||||
|
||||
# Build full hermes command
|
||||
full_command = f"{agent_config['hermes_path']} {hermes_command}"
|
||||
|
||||
return self.dispatch_command(agent_name, full_command, validate_first)
|
||||
|
||||
|
||||
class DispatchQueue:
|
||||
"""Queue for managing dispatch operations."""
|
||||
|
||||
def __init__(self, queue_file: Optional[str] = None):
|
||||
self.queue_file = queue_file or os.path.expanduser("~/.hermes/dispatch_queue.json")
|
||||
self.queue: List[Dict[str, Any]] = self._load_queue()
|
||||
|
||||
def _load_queue(self) -> List[Dict[str, Any]]:
|
||||
"""Load queue from file."""
|
||||
try:
|
||||
if os.path.exists(self.queue_file):
|
||||
with open(self.queue_file, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load dispatch queue: {e}")
|
||||
|
||||
return []
|
||||
|
||||
def save_queue(self):
|
||||
"""Save queue to file."""
|
||||
try:
|
||||
queue_dir = Path(self.queue_file).parent
|
||||
queue_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(self.queue_file, 'w') as f:
|
||||
json.dump(self.queue, f, indent=2)
|
||||
|
||||
# Set secure permissions
|
||||
os.chmod(self.queue_file, 0o600)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save dispatch queue: {e}")
|
||||
|
||||
def add_item(self, agent_name: str, command: str, priority: int = 0,
|
||||
max_retries: int = 3) -> str:
|
||||
"""
|
||||
Add an item to the dispatch queue.
|
||||
|
||||
Returns:
|
||||
Queue item ID
|
||||
"""
|
||||
item_id = f"dispatch_{int(time.time())}_{len(self.queue)}"
|
||||
|
||||
item = {
|
||||
"id": item_id,
|
||||
"agent_name": agent_name,
|
||||
"command": command,
|
||||
"priority": priority,
|
||||
"max_retries": max_retries,
|
||||
"retry_count": 0,
|
||||
"status": DispatchStatus.PENDING.value,
|
||||
"created_at": time.time(),
|
||||
"last_attempt": None,
|
||||
"result": None
|
||||
}
|
||||
|
||||
self.queue.append(item)
|
||||
self.save_queue()
|
||||
|
||||
return item_id
|
||||
|
||||
def get_next_item(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get the next item from the queue (highest priority, oldest first)."""
|
||||
if not self.queue:
|
||||
return None
|
||||
|
||||
# Sort by priority (descending) and created_at (ascending)
|
||||
sorted_queue = sorted(
|
||||
self.queue,
|
||||
key=lambda x: (-x.get("priority", 0), x.get("created_at", 0))
|
||||
)
|
||||
|
||||
# Find first pending item
|
||||
for item in sorted_queue:
|
||||
if item.get("status") == DispatchStatus.PENDING.value:
|
||||
return item
|
||||
|
||||
return None
|
||||
|
||||
def update_item(self, item_id: str, status: DispatchStatus,
|
||||
result: Optional[DispatchResult] = None):
|
||||
"""Update a queue item."""
|
||||
for item in self.queue:
|
||||
if item.get("id") == item_id:
|
||||
item["status"] = status.value
|
||||
item["last_attempt"] = time.time()
|
||||
|
||||
if result:
|
||||
item["result"] = {
|
||||
"status": result.status.value,
|
||||
"message": result.message,
|
||||
"exit_code": result.exit_code,
|
||||
"stdout": result.stdout,
|
||||
"stderr": result.stderr,
|
||||
"execution_time": result.execution_time,
|
||||
"hermes_path": result.hermes_path,
|
||||
"validated": result.validated
|
||||
}
|
||||
|
||||
# Update retry count if failed
|
||||
if status == DispatchStatus.FAILED:
|
||||
item["retry_count"] = item.get("retry_count", 0) + 1
|
||||
|
||||
self.save_queue()
|
||||
break
|
||||
|
||||
def remove_item(self, item_id: str):
|
||||
"""Remove an item from the queue."""
|
||||
self.queue = [item for item in self.queue if item.get("id") != item_id]
|
||||
self.save_queue()
|
||||
|
||||
def get_failed_items(self) -> List[Dict[str, Any]]:
|
||||
"""Get all failed items that can be retried."""
|
||||
return [
|
||||
item for item in self.queue
|
||||
if item.get("status") == DispatchStatus.FAILED.value
|
||||
and item.get("retry_count", 0) < item.get("max_retries", 3)
|
||||
]
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""Get queue statistics."""
|
||||
total = len(self.queue)
|
||||
pending = sum(1 for item in self.queue if item.get("status") == DispatchStatus.PENDING.value)
|
||||
success = sum(1 for item in self.queue if item.get("status") == DispatchStatus.SUCCESS.value)
|
||||
failed = sum(1 for item in self.queue if item.get("status") == DispatchStatus.FAILED.value)
|
||||
|
||||
return {
|
||||
"total": total,
|
||||
"pending": pending,
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"retryable": len(self.get_failed_items())
|
||||
}
|
||||
|
||||
|
||||
def process_dispatch_queue(dispatcher: VPSAgentDispatcher,
|
||||
queue: DispatchQueue,
|
||||
batch_size: int = 5) -> Dict[str, Any]:
|
||||
"""
|
||||
Process items from the dispatch queue.
|
||||
|
||||
Args:
|
||||
dispatcher: VPS agent dispatcher
|
||||
queue: Dispatch queue
|
||||
batch_size: Number of items to process in this batch
|
||||
|
||||
Returns:
|
||||
Processing statistics
|
||||
"""
|
||||
processed = 0
|
||||
success = 0
|
||||
failed = 0
|
||||
|
||||
for _ in range(batch_size):
|
||||
item = queue.get_next_item()
|
||||
if not item:
|
||||
break
|
||||
|
||||
item_id = item["id"]
|
||||
agent_name = item["agent_name"]
|
||||
command = item["command"]
|
||||
|
||||
# Update status to dispatching
|
||||
queue.update_item(item_id, DispatchStatus.DISPATCHING)
|
||||
|
||||
# Dispatch the command
|
||||
result = dispatcher.dispatch_hermes_command(
|
||||
agent_name=agent_name,
|
||||
hermes_command=command,
|
||||
validate_first=True
|
||||
)
|
||||
|
||||
# Update queue with result
|
||||
if result.status == DispatchStatus.SUCCESS:
|
||||
queue.update_item(item_id, DispatchStatus.SUCCESS, result)
|
||||
success += 1
|
||||
else:
|
||||
# Check if we should retry
|
||||
item_data = next((i for i in queue.queue if i.get("id") == item_id), None)
|
||||
if item_data and item_data.get("retry_count", 0) < item_data.get("max_retries", 3):
|
||||
queue.update_item(item_id, DispatchStatus.FAILED, result)
|
||||
failed += 1
|
||||
else:
|
||||
# Max retries reached, remove from queue
|
||||
queue.remove_item(item_id)
|
||||
failed += 1
|
||||
|
||||
processed += 1
|
||||
|
||||
return {
|
||||
"processed": processed,
|
||||
"success": success,
|
||||
"failed": failed,
|
||||
"queue_stats": queue.get_stats()
|
||||
}
|
||||
|
||||
|
||||
# Example usage and testing
|
||||
if __name__ == "__main__":
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Create dispatcher and queue
|
||||
dispatcher = VPSAgentDispatcher()
|
||||
queue = DispatchQueue()
|
||||
|
||||
# Example: Add items to queue
|
||||
queue.add_item("ezra", "cron list")
|
||||
queue.add_item("timmy", "cron status")
|
||||
|
||||
# Process queue
|
||||
stats = process_dispatch_queue(dispatcher, queue)
|
||||
print(f"Processing stats: {stats}")
|
||||
|
||||
# Show queue stats
|
||||
queue_stats = queue.get_stats()
|
||||
print(f"Queue stats: {queue_stats}")
|
||||
@@ -653,6 +653,12 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
# AIAgent.__init__ is missing params the scheduler expects.
|
||||
_validate_agent_interface()
|
||||
|
||||
# Check if this is a dispatch job
|
||||
if job.get("type") == "dispatch" or "dispatch" in job.get("name", "").lower():
|
||||
return _run_dispatch_job(job)
|
||||
|
||||
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
# Initialize SQLite session store so cron job messages are persisted
|
||||
@@ -1007,6 +1013,89 @@ def run_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
logger.debug("Job '%s': failed to close SQLite session store: %s", job_id, e)
|
||||
|
||||
|
||||
|
||||
def _run_dispatch_job(job: dict) -> tuple[bool, str, str, Optional[str]]:
|
||||
"""
|
||||
Execute a dispatch job that SSHs into remote VPS machines.
|
||||
|
||||
Returns:
|
||||
Tuple of (success, full_output_doc, final_response, error_message)
|
||||
"""
|
||||
from cron.dispatch_worker import VPSAgentDispatcher, DispatchQueue, process_dispatch_queue
|
||||
|
||||
job_id = job["id"]
|
||||
job_name = job["name"]
|
||||
|
||||
logger.info("Running dispatch job '%s' (ID: %s)", job_name, job_id)
|
||||
|
||||
try:
|
||||
# Load dispatch configuration
|
||||
dispatcher = VPSAgentDispatcher()
|
||||
queue = DispatchQueue()
|
||||
|
||||
# Get dispatch parameters from job
|
||||
agent_name = job.get("agent_name", "ezra")
|
||||
command = job.get("command", "cron list")
|
||||
batch_size = job.get("batch_size", 5)
|
||||
|
||||
# Add command to queue if specified
|
||||
if command:
|
||||
queue.add_item(agent_name, command)
|
||||
|
||||
# Process the dispatch queue
|
||||
stats = process_dispatch_queue(dispatcher, queue, batch_size)
|
||||
|
||||
# Generate output
|
||||
output = f"""# Dispatch Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Agent:** {agent_name}
|
||||
**Command:** {command}
|
||||
|
||||
## Dispatch Results
|
||||
|
||||
- **Processed:** {stats['processed']}
|
||||
- **Success:** {stats['success']}
|
||||
- **Failed:** {stats['failed']}
|
||||
|
||||
## Queue Statistics
|
||||
|
||||
- **Total items:** {stats['queue_stats']['total']}
|
||||
- **Pending:** {stats['queue_stats']['pending']}
|
||||
- **Success:** {stats['queue_stats']['success']}
|
||||
- **Failed:** {stats['queue_stats']['failed']}
|
||||
- **Retryable:** {stats['queue_stats']['retryable']}
|
||||
|
||||
## Status
|
||||
|
||||
{"✅ All dispatches successful" if stats['failed'] == 0 else f"⚠️ {stats['failed']} dispatches failed"}
|
||||
"""
|
||||
|
||||
success = stats['failed'] == 0
|
||||
error_message = None if success else f"{stats['failed']} dispatches failed"
|
||||
|
||||
return (success, output, output, error_message)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Dispatch job failed: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
||||
output = f"""# Dispatch Job: {job_name}
|
||||
|
||||
**Job ID:** {job_id}
|
||||
**Run Time:** {_hermes_now().strftime('%Y-%m-%d %H:%M:%S')}
|
||||
**Status:** ❌ Failed
|
||||
|
||||
## Error
|
||||
|
||||
{error_msg}
|
||||
"""
|
||||
|
||||
return (False, output, output, error_msg)
|
||||
|
||||
|
||||
|
||||
def tick(verbose: bool = True, adapters=None, loop=None) -> int:
|
||||
"""
|
||||
Check and run all due jobs.
|
||||
|
||||
@@ -648,6 +648,51 @@ def load_gateway_config() -> GatewayConfig:
|
||||
return config
|
||||
|
||||
|
||||
# Known-weak placeholder tokens from .env.example, tutorials, etc.
|
||||
_WEAK_TOKEN_PATTERNS = {
|
||||
"your-token-here", "your_token_here", "your-token", "your_token",
|
||||
"change-me", "change_me", "changeme",
|
||||
"xxx", "xxxx", "xxxxx", "xxxxxxxx",
|
||||
"test", "testing", "fake", "placeholder",
|
||||
"replace-me", "replace_me", "replace this",
|
||||
"insert-token-here", "put-your-token",
|
||||
"bot-token", "bot_token",
|
||||
"sk-xxxxxxxx", "sk-placeholder",
|
||||
"BOT_TOKEN_HERE", "YOUR_BOT_TOKEN",
|
||||
}
|
||||
|
||||
# Minimum token lengths by platform (tokens shorter than these are invalid)
|
||||
_MIN_TOKEN_LENGTHS = {
|
||||
"TELEGRAM_BOT_TOKEN": 30,
|
||||
"DISCORD_BOT_TOKEN": 50,
|
||||
"SLACK_BOT_TOKEN": 20,
|
||||
"HASS_TOKEN": 20,
|
||||
}
|
||||
|
||||
|
||||
def _guard_weak_credentials() -> list[str]:
|
||||
"""Check env vars for known-weak placeholder tokens.
|
||||
|
||||
Returns a list of warning messages for any weak credentials found.
|
||||
"""
|
||||
warnings = []
|
||||
for env_var, min_len in _MIN_TOKEN_LENGTHS.items():
|
||||
value = os.getenv(env_var, "").strip()
|
||||
if not value:
|
||||
continue
|
||||
if value.lower() in _WEAK_TOKEN_PATTERNS:
|
||||
warnings.append(
|
||||
f"{env_var} is set to a placeholder value ('{value[:20]}'). "
|
||||
f"Replace it with a real token."
|
||||
)
|
||||
elif len(value) < min_len:
|
||||
warnings.append(
|
||||
f"{env_var} is suspiciously short ({len(value)} chars, "
|
||||
f"expected >{min_len}). May be truncated or invalid."
|
||||
)
|
||||
return warnings
|
||||
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
@@ -941,3 +986,7 @@ def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
config.default_reset_policy.at_hour = int(reset_hour)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Guard against weak placeholder tokens from .env.example copies
|
||||
for warning in _guard_weak_credentials():
|
||||
logger.warning("Weak credential: %s", warning)
|
||||
|
||||
@@ -540,6 +540,29 @@ def handle_function_call(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Poka-yoke: validate tool handler return type.
|
||||
# Handlers MUST return a JSON string. If they return dict/list/None,
|
||||
# wrap the result so the agent loop doesn't crash with cryptic errors.
|
||||
if not isinstance(result, str):
|
||||
logger.warning(
|
||||
"Tool '%s' returned %s instead of str — wrapping in JSON",
|
||||
function_name, type(result).__name__,
|
||||
)
|
||||
result = json.dumps(
|
||||
{"output": str(result), "_type_warning": f"Tool returned {type(result).__name__}, expected str"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
else:
|
||||
# Validate it's parseable JSON
|
||||
try:
|
||||
json.loads(result)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
logger.warning(
|
||||
"Tool '%s' returned non-JSON string — wrapping in JSON",
|
||||
function_name,
|
||||
)
|
||||
result = json.dumps({"output": result}, ensure_ascii=False)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -12,7 +12,7 @@ Config in $HERMES_HOME/config.yaml (profile-scoped):
|
||||
auto_extract: false
|
||||
default_trust: 0.5
|
||||
min_trust_threshold: 0.3
|
||||
temporal_decay_half_life: 0
|
||||
temporal_decay_half_life: 60
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -152,6 +152,7 @@ class HolographicMemoryProvider(MemoryProvider):
|
||||
{"key": "auto_extract", "description": "Auto-extract facts at session end", "default": "false", "choices": ["true", "false"]},
|
||||
{"key": "default_trust", "description": "Default trust score for new facts", "default": "0.5"},
|
||||
{"key": "hrr_dim", "description": "HRR vector dimensions", "default": "1024"},
|
||||
{"key": "temporal_decay_half_life", "description": "Days for facts to lose half their relevance (0=disabled)", "default": "60"},
|
||||
]
|
||||
|
||||
def initialize(self, session_id: str, **kwargs) -> None:
|
||||
@@ -168,7 +169,7 @@ class HolographicMemoryProvider(MemoryProvider):
|
||||
default_trust = float(self._config.get("default_trust", 0.5))
|
||||
hrr_dim = int(self._config.get("hrr_dim", 1024))
|
||||
hrr_weight = float(self._config.get("hrr_weight", 0.3))
|
||||
temporal_decay = int(self._config.get("temporal_decay_half_life", 0))
|
||||
temporal_decay = int(self._config.get("temporal_decay_half_life", 60))
|
||||
|
||||
self._store = MemoryStore(db_path=db_path, default_trust=default_trust, hrr_dim=hrr_dim)
|
||||
self._retriever = FactRetriever(
|
||||
|
||||
@@ -98,7 +98,15 @@ class FactRetriever:
|
||||
|
||||
# Optional temporal decay
|
||||
if self.half_life > 0:
|
||||
score *= self._temporal_decay(fact.get("updated_at") or fact.get("created_at"))
|
||||
decay = self._temporal_decay(fact.get("updated_at") or fact.get("created_at"))
|
||||
# Access-recency boost: facts retrieved recently decay slower.
|
||||
# A fact accessed within 1 half-life gets up to 1.5x the decay
|
||||
# factor, tapering to 1.0x (no boost) after 2 half-lives.
|
||||
last_accessed = fact.get("last_accessed_at")
|
||||
if last_accessed:
|
||||
access_boost = self._access_recency_boost(last_accessed)
|
||||
decay = min(1.0, decay * access_boost)
|
||||
score *= decay
|
||||
|
||||
fact["score"] = score
|
||||
scored.append(fact)
|
||||
@@ -591,3 +599,41 @@ class FactRetriever:
|
||||
return math.pow(0.5, age_days / self.half_life)
|
||||
except (ValueError, TypeError):
|
||||
return 1.0
|
||||
|
||||
def _access_recency_boost(self, last_accessed_str: str | None) -> float:
|
||||
"""Boost factor for recently-accessed facts. Range [1.0, 1.5].
|
||||
|
||||
Facts accessed within 1 half-life get up to 1.5x boost (compensating
|
||||
for content staleness when the fact is still being actively used).
|
||||
Boost decays linearly to 1.0 (no boost) at 2 half-lives.
|
||||
|
||||
Returns 1.0 if half-life is disabled or timestamp is missing.
|
||||
"""
|
||||
if not self.half_life or not last_accessed_str:
|
||||
return 1.0
|
||||
|
||||
try:
|
||||
if isinstance(last_accessed_str, str):
|
||||
ts = datetime.fromisoformat(last_accessed_str.replace("Z", "+00:00"))
|
||||
else:
|
||||
ts = last_accessed_str
|
||||
|
||||
if ts.tzinfo is None:
|
||||
ts = ts.replace(tzinfo=timezone.utc)
|
||||
|
||||
age_days = (datetime.now(timezone.utc) - ts).total_seconds() / 86400
|
||||
if age_days < 0:
|
||||
return 1.5 # Future timestamp = just accessed
|
||||
|
||||
half_lives_since_access = age_days / self.half_life
|
||||
|
||||
if half_lives_since_access <= 1.0:
|
||||
# Within 1 half-life: linearly from 1.5 (just now) to 1.0 (at 1 HL)
|
||||
return 1.0 + 0.5 * (1.0 - half_lives_since_access)
|
||||
elif half_lives_since_access <= 2.0:
|
||||
# Between 1 and 2 half-lives: linearly from 1.0 to 1.0 (no boost)
|
||||
return 1.0
|
||||
else:
|
||||
return 1.0
|
||||
except (ValueError, TypeError):
|
||||
return 1.0
|
||||
|
||||
52
tests/gateway/test_weak_credential_guard.py
Normal file
52
tests/gateway/test_weak_credential_guard.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Tests for weak credential guard in gateway/config.py."""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from gateway.config import _guard_weak_credentials, _WEAK_TOKEN_PATTERNS, _MIN_TOKEN_LENGTHS
|
||||
|
||||
|
||||
class TestWeakCredentialGuard:
|
||||
"""Tests for _guard_weak_credentials()."""
|
||||
|
||||
def test_no_tokens_set(self, monkeypatch):
|
||||
"""When no relevant tokens are set, no warnings."""
|
||||
for var in _MIN_TOKEN_LENGTHS:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
warnings = _guard_weak_credentials()
|
||||
assert warnings == []
|
||||
|
||||
def test_placeholder_token_detected(self, monkeypatch):
|
||||
"""Known-weak placeholder tokens are flagged."""
|
||||
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "your-token-here")
|
||||
warnings = _guard_weak_credentials()
|
||||
assert len(warnings) == 1
|
||||
assert "TELEGRAM_BOT_TOKEN" in warnings[0]
|
||||
assert "placeholder" in warnings[0].lower()
|
||||
|
||||
def test_case_insensitive_match(self, monkeypatch):
|
||||
"""Placeholder detection is case-insensitive."""
|
||||
monkeypatch.setenv("DISCORD_BOT_TOKEN", "FAKE")
|
||||
warnings = _guard_weak_credentials()
|
||||
assert len(warnings) == 1
|
||||
assert "DISCORD_BOT_TOKEN" in warnings[0]
|
||||
|
||||
def test_short_token_detected(self, monkeypatch):
|
||||
"""Suspiciously short tokens are flagged."""
|
||||
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "abc123") # 6 chars, min is 30
|
||||
warnings = _guard_weak_credentials()
|
||||
assert len(warnings) == 1
|
||||
assert "short" in warnings[0].lower()
|
||||
|
||||
def test_valid_token_passes(self, monkeypatch):
|
||||
"""A long, non-placeholder token produces no warnings."""
|
||||
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "1234567890:ABCDEFGHIJKLMNOPQRSTUVWXYZ1234567")
|
||||
warnings = _guard_weak_credentials()
|
||||
assert warnings == []
|
||||
|
||||
def test_multiple_weak_tokens(self, monkeypatch):
|
||||
"""Multiple weak tokens each produce a warning."""
|
||||
monkeypatch.setenv("TELEGRAM_BOT_TOKEN", "change-me")
|
||||
monkeypatch.setenv("DISCORD_BOT_TOKEN", "xx") # short
|
||||
warnings = _guard_weak_credentials()
|
||||
assert len(warnings) == 2
|
||||
209
tests/plugins/memory/test_temporal_decay.py
Normal file
209
tests/plugins/memory/test_temporal_decay.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""Tests for temporal decay and access-recency boost in holographic memory (#241)."""
|
||||
|
||||
import math
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestTemporalDecay:
|
||||
"""Test _temporal_decay exponential decay formula."""
|
||||
|
||||
def _make_retriever(self, half_life=60):
|
||||
from plugins.memory.holographic.retrieval import FactRetriever
|
||||
store = MagicMock()
|
||||
return FactRetriever(store=store, temporal_decay_half_life=half_life)
|
||||
|
||||
def test_fresh_fact_no_decay(self):
|
||||
"""A fact updated today should have decay ≈ 1.0."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
decay = r._temporal_decay(now)
|
||||
assert decay > 0.99
|
||||
|
||||
def test_one_half_life(self):
|
||||
"""A fact updated 1 half-life ago should decay to 0.5."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=60)).isoformat()
|
||||
decay = r._temporal_decay(old)
|
||||
assert abs(decay - 0.5) < 0.01
|
||||
|
||||
def test_two_half_lives(self):
|
||||
"""A fact updated 2 half-lives ago should decay to 0.25."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=120)).isoformat()
|
||||
decay = r._temporal_decay(old)
|
||||
assert abs(decay - 0.25) < 0.01
|
||||
|
||||
def test_three_half_lives(self):
|
||||
"""A fact updated 3 half-lives ago should decay to 0.125."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=180)).isoformat()
|
||||
decay = r._temporal_decay(old)
|
||||
assert abs(decay - 0.125) < 0.01
|
||||
|
||||
def test_half_life_disabled(self):
|
||||
"""When half_life=0, decay should always be 1.0."""
|
||||
r = self._make_retriever(half_life=0)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=365)).isoformat()
|
||||
assert r._temporal_decay(old) == 1.0
|
||||
|
||||
def test_none_timestamp(self):
|
||||
"""Missing timestamp should return 1.0 (no decay)."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
assert r._temporal_decay(None) == 1.0
|
||||
|
||||
def test_empty_timestamp(self):
|
||||
r = self._make_retriever(half_life=60)
|
||||
assert r._temporal_decay("") == 1.0
|
||||
|
||||
def test_invalid_timestamp(self):
|
||||
"""Malformed timestamp should return 1.0 (fail open)."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
assert r._temporal_decay("not-a-date") == 1.0
|
||||
|
||||
def test_future_timestamp(self):
|
||||
"""Future timestamp should return 1.0 (no decay for future dates)."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
future = (datetime.now(timezone.utc) + timedelta(days=10)).isoformat()
|
||||
assert r._temporal_decay(future) == 1.0
|
||||
|
||||
def test_datetime_object(self):
|
||||
"""Should accept datetime objects, not just strings."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
old = datetime.now(timezone.utc) - timedelta(days=60)
|
||||
decay = r._temporal_decay(old)
|
||||
assert abs(decay - 0.5) < 0.01
|
||||
|
||||
def test_different_half_lives(self):
|
||||
"""30-day half-life should decay faster than 90-day."""
|
||||
r30 = self._make_retriever(half_life=30)
|
||||
r90 = self._make_retriever(half_life=90)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=45)).isoformat()
|
||||
assert r30._temporal_decay(old) < r90._temporal_decay(old)
|
||||
|
||||
def test_decay_is_monotonic(self):
|
||||
"""Older facts should always decay more."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
now = datetime.now(timezone.utc)
|
||||
d1 = r._temporal_decay((now - timedelta(days=10)).isoformat())
|
||||
d2 = r._temporal_decay((now - timedelta(days=30)).isoformat())
|
||||
d3 = r._temporal_decay((now - timedelta(days=60)).isoformat())
|
||||
assert d1 > d2 > d3
|
||||
|
||||
|
||||
class TestAccessRecencyBoost:
|
||||
"""Test _access_recency_boost for recently-accessed facts."""
|
||||
|
||||
def _make_retriever(self, half_life=60):
|
||||
from plugins.memory.holographic.retrieval import FactRetriever
|
||||
store = MagicMock()
|
||||
return FactRetriever(store=store, temporal_decay_half_life=half_life)
|
||||
|
||||
def test_just_accessed_max_boost(self):
|
||||
"""A fact accessed just now should get maximum boost (1.5)."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
boost = r._access_recency_boost(now)
|
||||
assert boost > 1.45 # Near 1.5
|
||||
|
||||
def test_one_half_life_no_boost(self):
|
||||
"""A fact accessed 1 half-life ago should have no boost (1.0)."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=60)).isoformat()
|
||||
boost = r._access_recency_boost(old)
|
||||
assert abs(boost - 1.0) < 0.01
|
||||
|
||||
def test_half_way_boost(self):
|
||||
"""A fact accessed 0.5 half-lives ago should get ~1.25 boost."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=30)).isoformat()
|
||||
boost = r._access_recency_boost(old)
|
||||
assert abs(boost - 1.25) < 0.05
|
||||
|
||||
def test_beyond_one_half_life_no_boost(self):
|
||||
"""Beyond 1 half-life, boost should be 1.0."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=90)).isoformat()
|
||||
boost = r._access_recency_boost(old)
|
||||
assert boost == 1.0
|
||||
|
||||
def test_disabled_no_boost(self):
|
||||
"""When half_life=0, boost should be 1.0."""
|
||||
r = self._make_retriever(half_life=0)
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
assert r._access_recency_boost(now) == 1.0
|
||||
|
||||
def test_none_timestamp(self):
|
||||
r = self._make_retriever(half_life=60)
|
||||
assert r._access_recency_boost(None) == 1.0
|
||||
|
||||
def test_invalid_timestamp(self):
|
||||
r = self._make_retriever(half_life=60)
|
||||
assert r._access_recency_boost("bad") == 1.0
|
||||
|
||||
def test_boost_range(self):
|
||||
"""Boost should always be in [1.0, 1.5]."""
|
||||
r = self._make_retriever(half_life=60)
|
||||
now = datetime.now(timezone.utc)
|
||||
for days in [0, 1, 15, 30, 45, 59, 60, 90, 365]:
|
||||
ts = (now - timedelta(days=days)).isoformat()
|
||||
boost = r._access_recency_boost(ts)
|
||||
assert 1.0 <= boost <= 1.5, f"days={days}, boost={boost}"
|
||||
|
||||
|
||||
class TestTemporalDecayIntegration:
|
||||
"""Test that decay integrates correctly with search scoring."""
|
||||
|
||||
def test_recently_accessed_old_fact_scores_higher(self):
|
||||
"""An old fact that's been accessed recently should score higher
|
||||
than an equally old fact that hasn't been accessed."""
|
||||
from plugins.memory.holographic.retrieval import FactRetriever
|
||||
store = MagicMock()
|
||||
r = FactRetriever(store=store, temporal_decay_half_life=60)
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
old_date = (now - timedelta(days=120)).isoformat() # 2 half-lives old
|
||||
recent_access = (now - timedelta(days=10)).isoformat() # accessed 10 days ago
|
||||
old_access = (now - timedelta(days=200)).isoformat() # accessed 200 days ago
|
||||
|
||||
# Old fact, recently accessed
|
||||
decay1 = r._temporal_decay(old_date)
|
||||
boost1 = r._access_recency_boost(recent_access)
|
||||
effective1 = min(1.0, decay1 * boost1)
|
||||
|
||||
# Old fact, not recently accessed
|
||||
decay2 = r._temporal_decay(old_date)
|
||||
boost2 = r._access_recency_boost(old_access)
|
||||
effective2 = min(1.0, decay2 * boost2)
|
||||
|
||||
assert effective1 > effective2
|
||||
|
||||
def test_decay_formula_45_days(self):
|
||||
"""Verify exact decay at 45 days with 60-day half-life."""
|
||||
from plugins.memory.holographic.retrieval import FactRetriever
|
||||
r = FactRetriever(store=MagicMock(), temporal_decay_half_life=60)
|
||||
old = (datetime.now(timezone.utc) - timedelta(days=45)).isoformat()
|
||||
decay = r._temporal_decay(old)
|
||||
expected = math.pow(0.5, 45/60)
|
||||
assert abs(decay - expected) < 0.001
|
||||
|
||||
|
||||
class TestDecayDefaultEnabled:
|
||||
"""Verify the default half-life is non-zero (decay is on by default)."""
|
||||
|
||||
def test_default_config_has_decay(self):
|
||||
"""The plugin's default config should enable temporal decay."""
|
||||
from plugins.memory.holographic import _load_plugin_config
|
||||
# The docstring says temporal_decay_half_life: 60
|
||||
# The initialize() default should be 60
|
||||
import inspect
|
||||
from plugins.memory.holographic import HolographicMemoryProvider
|
||||
src = inspect.getsource(HolographicMemoryProvider.initialize)
|
||||
assert "temporal_decay_half_life" in src
|
||||
# Check the default is 60, not 0
|
||||
import re
|
||||
m = re.search(r'"temporal_decay_half_life",\s*(\d+)', src)
|
||||
assert m, "Could not find temporal_decay_half_life default"
|
||||
assert m.group(1) == "60", f"Default is {m.group(1)}, expected 60"
|
||||
@@ -137,3 +137,78 @@ class TestBackwardCompat:
|
||||
def test_tool_to_toolset_map(self):
|
||||
assert isinstance(TOOL_TO_TOOLSET_MAP, dict)
|
||||
assert len(TOOL_TO_TOOLSET_MAP) > 0
|
||||
|
||||
|
||||
class TestToolReturnTypeValidation:
|
||||
"""Poka-yoke: tool handlers must return JSON strings."""
|
||||
|
||||
def test_handler_returning_dict_is_wrapped(self, monkeypatch):
|
||||
"""A handler that returns a dict should be auto-wrapped to JSON string."""
|
||||
from tools.registry import registry
|
||||
from model_tools import handle_function_call
|
||||
import json
|
||||
|
||||
# Register a bad handler that returns dict instead of str
|
||||
registry.register(
|
||||
name="__test_bad_dict",
|
||||
toolset="test",
|
||||
schema={"name": "__test_bad_dict", "description": "test", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda args, **kw: {"this is": "a dict not a string"},
|
||||
)
|
||||
result = handle_function_call("__test_bad_dict", {})
|
||||
parsed = json.loads(result)
|
||||
assert "output" in parsed
|
||||
assert "_type_warning" in parsed
|
||||
# Cleanup
|
||||
registry._tools.pop("__test_bad_dict", None)
|
||||
|
||||
def test_handler_returning_none_is_wrapped(self, monkeypatch):
|
||||
"""A handler that returns None should be auto-wrapped."""
|
||||
from tools.registry import registry
|
||||
from model_tools import handle_function_call
|
||||
import json
|
||||
|
||||
registry.register(
|
||||
name="__test_bad_none",
|
||||
toolset="test",
|
||||
schema={"name": "__test_bad_none", "description": "test", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda args, **kw: None,
|
||||
)
|
||||
result = handle_function_call("__test_bad_none", {})
|
||||
parsed = json.loads(result)
|
||||
assert "_type_warning" in parsed
|
||||
registry._tools.pop("__test_bad_none", None)
|
||||
|
||||
def test_handler_returning_non_json_string_is_wrapped(self):
|
||||
"""A handler returning a plain string (not JSON) should be wrapped."""
|
||||
from tools.registry import registry
|
||||
from model_tools import handle_function_call
|
||||
import json
|
||||
|
||||
registry.register(
|
||||
name="__test_bad_plain",
|
||||
toolset="test",
|
||||
schema={"name": "__test_bad_plain", "description": "test", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda args, **kw: "just a plain string, not json",
|
||||
)
|
||||
result = handle_function_call("__test_bad_plain", {})
|
||||
parsed = json.loads(result)
|
||||
assert "output" in parsed
|
||||
registry._tools.pop("__test_bad_plain", None)
|
||||
|
||||
def test_handler_returning_valid_json_passes_through(self):
|
||||
"""A handler returning valid JSON string passes through unchanged."""
|
||||
from tools.registry import registry
|
||||
from model_tools import handle_function_call
|
||||
import json
|
||||
|
||||
registry.register(
|
||||
name="__test_good",
|
||||
toolset="test",
|
||||
schema={"name": "__test_good", "description": "test", "parameters": {"type": "object", "properties": {}}},
|
||||
handler=lambda args, **kw: json.dumps({"status": "ok", "data": [1, 2, 3]}),
|
||||
)
|
||||
result = handle_function_call("__test_good", {})
|
||||
parsed = json.loads(result)
|
||||
assert parsed == {"status": "ok", "data": [1, 2, 3]}
|
||||
registry._tools.pop("__test_good", None)
|
||||
|
||||
Reference in New Issue
Block a user