Compare commits
69 Commits
fix/format
...
feat/stati
| Author | SHA1 | Date | |
|---|---|---|---|
| 93a855d4e3 | |||
| a2a40429bd | |||
| ee61c5fa9d | |||
|
|
1fece10569 | ||
| 46668505bc | |||
| cac0c8224e | |||
| f38a64455d | |||
| 1b35a5a0d2 | |||
| 9172131b25 | |||
| 407eab3331 | |||
| cf090a966d | |||
| b65be9b12c | |||
| 3c1cff255e | |||
| 690d100afc | |||
| c6f0831738 | |||
| 30773ac1f9 | |||
| feb24bd08c | |||
| bc55f40505 | |||
| 2adc72335e | |||
| ab32670464 | |||
| bfc0231297 | |||
| cf2b09cf2f | |||
| 719bb537c0 | |||
| 0bcbcf19ac | |||
| 27d2f2ca0e | |||
| 7e7dcfa345 | |||
| ba0e614446 | |||
| 4f5e641c92 | |||
| d61bd141f9 | |||
| a4058af238 | |||
| 08432a5618 | |||
| a875c6ed91 | |||
| 07c5b5b83d | |||
| ba56567631 | |||
| 8ac26f54a5 | |||
| b807972d05 | |||
| 6b5a6db668 | |||
| b702249c12 | |||
|
|
8023c9b8f2 | ||
| 6eeee39c10 | |||
| b2d2d2c650 | |||
| bdd0f2709b | |||
| a9cbf7d69f | |||
| 4cdda8701d | |||
| a80d30b342 | |||
| f098cf8c4a | |||
| 30509b9c7c | |||
| ccaa1cb021 | |||
| c6f2855745 | |||
| 9d180f31cc | |||
| c17f64fa2c | |||
| bc7ffc2166 | |||
|
|
c22cdcaa8e | ||
|
|
ab968e910c | ||
|
|
73984ca72f | ||
| 436c800def | |||
| cb331da4f1 | |||
| fa892bfcb9 | |||
|
|
0b72884750 | ||
| a0ed1e6ff2 | |||
| b5ba272efe | |||
| 2e0dfe27df | |||
| d4cdfdc604 | |||
| e3436e36c3 | |||
| 34e7de6a4c | |||
| dbabe0e6ae | |||
| 517e2c571e | |||
| 0b019327a3 | |||
| 6b0fca6944 |
28
.gitea/workflows/lint.yml
Normal file
28
.gitea/workflows/lint.yml
Normal file
@@ -0,0 +1,28 @@
|
||||
# CI: lightweight lint checks on pushes and PRs to main.
name: Lint

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      # Advisory only: failures do not block the build.
      - name: Check for hardcoded paths
        run: python3 scripts/lint_hardcoded_paths.py
        continue-on-error: true

      # Best-effort syntax sweep over the first 100 Python files.
      - name: Check Python syntax
        run: |
          find . -name "*.py" -not -path "./.git/*" -not -path "./node_modules/*" | head -100 | xargs python3 -m py_compile || true
|
||||
78
.githooks/pre-commit-hardcoded-path.py
Normal file
78
.githooks/pre-commit-hardcoded-path.py
Normal file
@@ -0,0 +1,78 @@
|
||||
#!/usr/bin/env python3
"""
Pre-commit hook: Reject hardcoded home-directory paths.

Install:
    cp pre-commit-hardcoded-path.py .git/hooks/pre-commit-hardcoded-path
    chmod +x .git/hooks/pre-commit-hardcoded-path

Or add to .pre-commit-config.yaml
"""

import sys
import subprocess
import re

# (regex, human-readable description) pairs flagged as hardcoded paths.
PATTERNS = [
    (r"/Users/[\w.\-]+/", "macOS home directory"),
    (r"/home/[\w.\-]+/", "Linux home directory"),
    (r"(?<![\w/])~/", "unexpanded tilde"),
]

# Opt-out marker: a line carrying this comment is never flagged.
NOQA = re.compile(r"#\s*noqa:?\s*hardcoded-path-ok")
||||
def get_staged_files():
    """Return the staged Python files (added/copied/modified) from the git index."""
    proc = subprocess.run(
        ["git", "diff", "--cached", "--name-only", "--diff-filter=ACM"],
        capture_output=True,
        text=True,
    )
    staged = proc.stdout.strip().split("\n")
    return [name for name in staged if name.endswith(".py")]
||||
def check_file(filepath):
    """Scan the staged content of *filepath* for hardcoded home-directory paths.

    Reads the file from the git index (not the working tree) so the check
    matches exactly what would be committed.

    Returns:
        List of (filepath, line_number, stripped_line, description) tuples.
    """
    try:
        shown = subprocess.run(
            ["git", "show", f":{filepath}"],
            capture_output=True, text=True
        )
        content = shown.stdout
    except Exception:
        # Can't read from the index: nothing to report.
        return []

    violations = []
    for line_no, raw_line in enumerate(content.split("\n"), 1):
        stripped = raw_line.strip()
        # Pure comments and import lines routinely mention paths harmlessly.
        if stripped.startswith("#"):
            continue
        if stripped.startswith(("import ", "from ")):
            continue
        # Explicit opt-out marker on the line.
        if NOQA.search(raw_line):
            continue
        for pattern, desc in PATTERNS:
            if re.search(pattern, raw_line):
                violations.append((filepath, line_no, stripped, desc))
                break  # one violation per line is enough
    return violations
|
||||
|
||||
def main():
    """Entry point: exit 1 if any staged .py file contains a hardcoded path."""
    staged = get_staged_files()
    if not staged:
        sys.exit(0)

    findings = []
    for path in staged:
        findings.extend(check_file(path))

    if findings:
        print("ERROR: Hardcoded home directory paths detected:")
        print()
        for filepath, line_no, line, desc in findings:
            print(f" {filepath}:{line_no}: {desc}")
            print(f" {line[:100]}")
        print()
        print("Fix: Use $HOME, relative paths, or get_hermes_home().")
        print("Override: Add '# noqa: hardcoded-path-ok' to the line.")
        sys.exit(1)

    sys.exit(0)


if __name__ == "__main__":
    main()
|
||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -25,6 +25,10 @@ jobs:
|
||||
- name: Install system dependencies
|
||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||
|
||||
- name: Check for hardcoded paths
|
||||
run: python3 scripts/lint_hardcoded_paths.py || true
|
||||
continue-on-error: true
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||
|
||||
|
||||
273
agent/circuit_breaker.py
Normal file
273
agent/circuit_breaker.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Circuit Breaker for Error Cascading — #885
|
||||
|
||||
P(error | prev was error) = 58.6% vs P(error | prev was success) = 25.2%.
|
||||
That's a 2.33x cascade factor. After 3 consecutive errors, the circuit
|
||||
opens and the agent must take corrective action.
|
||||
|
||||
States:
|
||||
- CLOSED: Normal operation, errors are counted
|
||||
- OPEN: Too many consecutive errors, corrective action required
|
||||
- HALF_OPEN: Testing if errors have cleared
|
||||
|
||||
Usage:
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker
|
||||
|
||||
cb = ToolCircuitBreaker()
|
||||
|
||||
# After each tool call
|
||||
if not cb.record_result(success=True):
|
||||
# Circuit is open — take corrective action
|
||||
cb.get_recovery_action()
|
||||
"""
|
||||
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class CircuitState(Enum):
    """Lifecycle states of a circuit breaker."""
    CLOSED = "closed"  # Normal operation
    OPEN = "open"  # Too many errors, block execution
    HALF_OPEN = "half_open"  # Testing recovery


@dataclass
class CircuitBreaker:
    """
    Generic circuit breaker with configurable thresholds.

    Tracks consecutive errors and opens the circuit when the
    error streak exceeds the threshold.
    """
    # Tunables (constructor arguments).
    failure_threshold: int = 3
    recovery_timeout: float = 30.0  # seconds before trying half-open
    success_threshold: int = 2  # successes needed to close from half-open

    # Runtime state (not constructor arguments).
    state: CircuitState = field(default=CircuitState.CLOSED, init=False)
    consecutive_failures: int = field(default=0, init=False)
    consecutive_successes: int = field(default=0, init=False)
    last_failure_time: Optional[float] = field(default=None, init=False)
    total_trips: int = field(default=0, init=False)  # times the circuit opened
    error_streaks: List[int] = field(default_factory=list, init=False)  # streak length at each trip

    def record_result(self, success: bool) -> bool:
        """
        Record a tool call result. Returns True if circuit allows execution.

        Returns:
            True if circuit is CLOSED or HALF_OPEN (execution allowed)
            False if circuit is OPEN (execution blocked)
        """
        now = time.time()

        if self.state == CircuitState.OPEN:
            # While open, only expiry of the recovery timeout lets a probe
            # call through (transitioning to HALF_OPEN).
            if self.last_failure_time and (now - self.last_failure_time) >= self.recovery_timeout:
                self.state = CircuitState.HALF_OPEN
                self.consecutive_successes = 0
                return True
            return False

        if success:
            self.consecutive_failures = 0
            self.consecutive_successes += 1
            reached_quota = self.consecutive_successes >= self.success_threshold
            if self.state == CircuitState.HALF_OPEN and reached_quota:
                # Enough clean calls in a row: recovery confirmed.
                self.state = CircuitState.CLOSED
                self.consecutive_successes = 0
            return True

        # Failure path.
        self.consecutive_successes = 0
        self.consecutive_failures += 1
        self.last_failure_time = now

        if self.state == CircuitState.HALF_OPEN:
            # Failed during recovery — reopen immediately.
            self.state = CircuitState.OPEN
            self.total_trips += 1
            return False

        if self.consecutive_failures >= self.failure_threshold:
            self.state = CircuitState.OPEN
            self.total_trips += 1
            self.error_streaks.append(self.consecutive_failures)
            return False

        return True

    def can_execute(self) -> bool:
        """Check if execution is allowed.

        NOTE: not a pure query — when the recovery timeout has elapsed this
        transitions the breaker from OPEN to HALF_OPEN as a side effect.
        """
        if self.state != CircuitState.OPEN:
            return True
        timed_out = (
            self.last_failure_time
            and (time.time() - self.last_failure_time) >= self.recovery_timeout
        )
        if timed_out:
            self.state = CircuitState.HALF_OPEN
            self.consecutive_successes = 0
            return True
        return False

    def get_state(self) -> Dict[str, Any]:
        """Get current circuit state as a plain dict (for logging/inspection)."""
        longest_streak = max(self.error_streaks) if self.error_streaks else 0
        return {
            "state": self.state.value,
            "consecutive_failures": self.consecutive_failures,
            "consecutive_successes": self.consecutive_successes,
            "total_trips": self.total_trips,
            "max_streak": longest_streak,
            "can_execute": self.can_execute(),  # NOTE: may flip OPEN -> HALF_OPEN
        }

    def reset(self):
        """Reset the circuit breaker to a pristine CLOSED state."""
        self.state = CircuitState.CLOSED
        self.consecutive_failures = 0
        self.consecutive_successes = 0
        self.last_failure_time = None


class ToolCircuitBreaker(CircuitBreaker):
    """
    Circuit breaker specifically for tool call error cascading.

    Provides recovery actions when the circuit opens.
    """

    # Tools that are most effective at recovery (from audit data)
    RECOVERY_TOOLS = [
        "terminal",      # Most effective — 2300 recoveries
        "read_file",     # Reset context by reading something
        "search_files",  # Find what went wrong
    ]

    def get_recovery_action(self) -> Dict[str, Any]:
        """
        Get the recommended recovery action when circuit is open.

        Returns dict with action type and details.
        """
        streak = self.consecutive_failures

        # Severity tiers: the deeper the streak, the more drastic the action.
        if streak >= 9:
            # After 9 errors: 41/46 recoveries via terminal
            return {
                "action": "terminal_only",
                "reason": f"Error streak of {streak} — terminal is the only reliable recovery",
                "suggested_tool": "terminal",
                "suggested_command": "echo 'Resetting context'",
                "severity": "critical",
            }
        if streak >= 5:
            return {
                "action": "switch_tool_type",
                "reason": f"Error streak of {streak} — switch to a different tool category",
                "suggested_tools": ["read_file", "search_files", "terminal"],
                "severity": "high",
            }
        if streak >= self.failure_threshold:
            return {
                "action": "ask_user",
                "reason": f"{streak} consecutive errors — ask user for guidance",
                "suggested_response": "I'm encountering repeated errors. Would you like me to try a different approach?",
                "severity": "medium",
            }
        return {
            "action": "continue",
            "reason": f"Error streak of {streak} — within tolerance",
            "severity": "low",
        }

    def should_compress_context(self) -> bool:
        """Determine if context compression would help recovery."""
        return self.consecutive_failures >= 5

    def get_blocked_tool(self) -> Optional[str]:
        """Get the tool that should be blocked (if any)."""
        return "last_failed_tool" if self.state == CircuitState.OPEN else None


class MultiToolCircuitBreaker:
    """
    Manages per-tool circuit breakers and cross-tool cascade detection.

    When one tool trips its breaker, related tools are also warned.
    """

    def __init__(self):
        self.breakers: Dict[str, ToolCircuitBreaker] = {}  # per-tool breakers
        self.global_streak: int = 0            # consecutive failures across ALL tools
        self.last_tool: Optional[str] = None   # most recently recorded tool
        self.last_success: bool = True         # outcome of the last recorded call

    def get_breaker(self, tool_name: str) -> ToolCircuitBreaker:
        """Get or create a circuit breaker for a tool."""
        breaker = self.breakers.get(tool_name)
        if breaker is None:
            breaker = ToolCircuitBreaker()
            self.breakers[tool_name] = breaker
        return breaker

    def record_result(self, tool_name: str, success: bool) -> bool:
        """
        Record a tool call result. Returns True if execution should continue.
        """
        allowed = self.get_breaker(tool_name).record_result(success)

        # Global streak tracks failures regardless of which tool produced them.
        self.global_streak = 0 if success else self.global_streak + 1
        self.last_success = success
        self.last_tool = tool_name
        return allowed

    def can_execute(self, tool_name: str) -> bool:
        """Check if a specific tool can execute."""
        return self.get_breaker(tool_name).can_execute()

    def get_global_state(self) -> Dict[str, Any]:
        """Get overall circuit breaker state."""
        interesting = {
            name: breaker.get_state()
            for name, breaker in self.breakers.items()
            if breaker.consecutive_failures > 0 or breaker.total_trips > 0
        }
        return {
            "global_streak": self.global_streak,
            "last_tool": self.last_tool,
            "last_success": self.last_success,
            "tool_states": interesting,
            "any_open": any(b.state == CircuitState.OPEN for b in self.breakers.values()),
        }

    def get_recovery_action(self) -> Dict[str, Any]:
        """Get recovery action based on global state."""
        if self.global_streak == 0:
            return {"action": "continue", "reason": "No errors"}

        # Delegate to whichever breaker is currently in the worst streak.
        worst = max(self.breakers.values(), key=lambda b: b.consecutive_failures, default=None)
        if worst and worst.consecutive_failures > 0:
            return worst.get_recovery_action()

        return {
            "action": "continue",
            "reason": f"Global streak: {self.global_streak}",
            "severity": "low",
        }

    def reset_all(self):
        """Reset all circuit breakers."""
        for breaker in self.breakers.values():
            breaker.reset()
        self.global_streak = 0
        self.last_success = True
|
||||
148
agent/context_budget.py
Normal file
148
agent/context_budget.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""
|
||||
Context Budget Tracker - Prevent context window overflow
|
||||
|
||||
Poka-yoke: Visual warnings at 70%%, 85%%, 95%% capacity.
|
||||
Auto-checkpoint at 85%%. Pre-flight token estimation.
|
||||
|
||||
Issue: #838
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HERMES_HOME = Path.home() / ".hermes"
|
||||
CHECKPOINT_DIR = HERMES_HOME / "checkpoints"
|
||||
CHARS_PER_TOKEN = 4
|
||||
|
||||
THRESHOLD_WARNING = 0.70
|
||||
THRESHOLD_CRITICAL = 0.85
|
||||
THRESHOLD_DANGER = 0.95
|
||||
|
||||
|
||||
class ContextBudget:
|
||||
def __init__(self, context_limit: int = 128000, system_tokens: int = 0,
|
||||
used_tokens: int = 0, reserved_tokens: int = 2000):
|
||||
self.context_limit = context_limit
|
||||
self.system_tokens = system_tokens
|
||||
self.used_tokens = used_tokens
|
||||
self.reserved_tokens = reserved_tokens
|
||||
|
||||
@property
|
||||
def total_used(self) -> int:
|
||||
return self.system_tokens + self.used_tokens
|
||||
|
||||
@property
|
||||
def available(self) -> int:
|
||||
return max(0, self.context_limit - self.reserved_tokens)
|
||||
|
||||
@property
|
||||
def remaining(self) -> int:
|
||||
return max(0, self.available - self.total_used)
|
||||
|
||||
@property
|
||||
def utilization(self) -> float:
|
||||
return self.total_used / self.available if self.available > 0 else 1.0
|
||||
|
||||
|
||||
def estimate_tokens(text: str) -> int:
|
||||
return len(text) // CHARS_PER_TOKEN if text else 0
|
||||
|
||||
|
||||
def estimate_messages_tokens(messages: List[Dict]) -> int:
|
||||
total = 0
|
||||
for msg in messages:
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
total += estimate_tokens(content)
|
||||
if msg.get("tool_calls"):
|
||||
total += 100
|
||||
return total
|
||||
|
||||
|
||||
class ContextBudgetTracker:
|
||||
def __init__(self, context_limit: int = 128000, session_id: str = ""):
|
||||
self.budget = ContextBudget(context_limit=context_limit)
|
||||
self.session_id = session_id
|
||||
self._checkpointed = False
|
||||
self._warnings_given = set()
|
||||
|
||||
def update_from_messages(self, messages: List[Dict]):
|
||||
self.budget.used_tokens = estimate_messages_tokens(messages)
|
||||
|
||||
def can_fit(self, additional_tokens: int) -> bool:
|
||||
return self.budget.remaining >= additional_tokens
|
||||
|
||||
def preflight_check(self, text: str) -> Tuple[bool, str]:
|
||||
tokens = estimate_tokens(text)
|
||||
if not self.can_fit(tokens):
|
||||
return False, f"Cannot load: ~{tokens:,} tokens needed, {self.budget.remaining:,} remaining"
|
||||
would_util = (self.budget.total_used + tokens) / self.budget.available if self.budget.available > 0 else 1.0
|
||||
if would_util >= THRESHOLD_DANGER:
|
||||
return False, f"Would reach {would_util:.0%%} capacity. Summarize or start new session."
|
||||
if would_util >= THRESHOLD_CRITICAL:
|
||||
return True, f"Warning: will reach {would_util:.0%%} capacity."
|
||||
return True, ""
|
||||
|
||||
def get_warning(self) -> Optional[str]:
|
||||
util = self.budget.utilization
|
||||
if util >= THRESHOLD_DANGER and "danger" not in self._warnings_given:
|
||||
self._warnings_given.add("danger")
|
||||
return f"[CONTEXT CRITICAL: {util:.0%%} used -- {self.budget.remaining:,} tokens left. Summarize or start new session.]"
|
||||
if util >= THRESHOLD_CRITICAL and "critical" not in self._warnings_given:
|
||||
self._warnings_given.add("critical")
|
||||
self._auto_checkpoint()
|
||||
return f"[CONTEXT WARNING: {util:.0%%} used -- consider summarizing. Auto-checkpoint saved.]"
|
||||
if util >= THRESHOLD_WARNING and "warning" not in self._warnings_given:
|
||||
self._warnings_given.add("warning")
|
||||
return f"[CONTEXT: {util:.0%%} used -- {self.budget.remaining:,} tokens remaining]"
|
||||
return None
|
||||
|
||||
def _auto_checkpoint(self):
|
||||
if self._checkpointed or not self.session_id:
|
||||
return
|
||||
try:
|
||||
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = CHECKPOINT_DIR / f"{self.session_id}.json"
|
||||
path.write_text(json.dumps({
|
||||
"session_id": self.session_id,
|
||||
"timestamp": time.time(),
|
||||
"budget": {"utilization": round(self.budget.utilization * 100, 1)}
|
||||
}, indent=2))
|
||||
self._checkpointed = True
|
||||
logger.info("Auto-checkpoint saved: %s", path)
|
||||
except Exception as e:
|
||||
logger.error("Auto-checkpoint failed: %s", e)
|
||||
|
||||
def get_status_line(self) -> str:
|
||||
util = self.budget.utilization
|
||||
remaining = self.budget.remaining
|
||||
if util >= THRESHOLD_DANGER:
|
||||
return f"RED {util:.0%%} used ({remaining:,} left)"
|
||||
elif util >= THRESHOLD_CRITICAL:
|
||||
return f"ORANGE {util:.0%%} used ({remaining:,} left)"
|
||||
elif util >= THRESHOLD_WARNING:
|
||||
return f"YELLOW {util:.0%%} used ({remaining:,} left)"
|
||||
return f"GREEN {util:.0%%} used ({remaining:,} left)"
|
||||
|
||||
|
||||
# Module-level singleton tracker shared by the helper functions below.
_tracker = None


def get_tracker(context_limit=128000, session_id=""):
    """Return the process-wide ContextBudgetTracker singleton.

    NOTE(review): arguments are only honoured on the first call; later calls
    return the existing tracker unchanged — confirm this is intended.
    """
    global _tracker
    if _tracker is None:
        _tracker = ContextBudgetTracker(context_limit, session_id)
    return _tracker


def check_context_budget(messages, context_limit=128000):
    """Update the singleton from *messages* and return a threshold warning (or None)."""
    tracker = get_tracker(context_limit)
    tracker.update_from_messages(messages)
    return tracker.get_warning()


def preflight_token_check(text):
    """Run the singleton tracker's preflight check on *text*."""
    tracker = get_tracker()
    return tracker.preflight_check(text)
|
||||
149
agent/crisis_resources.py
Normal file
149
agent/crisis_resources.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
988 Suicide & Crisis Lifeline Integration (#673).
|
||||
|
||||
When crisis is detected, provides immediate access to help:
|
||||
- Phone: 988 (call or text)
|
||||
- Text: Text HOME to 988
|
||||
- Chat: 988lifeline.org/chat
|
||||
- Spanish: 1-888-628-9454
|
||||
- Emergency: 911
|
||||
|
||||
This module provides the resource data. agent/crisis_protocol.py
|
||||
handles detection. This module formats the resources for display.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrisisResource:
|
||||
"""A crisis support contact method."""
|
||||
name: str
|
||||
contact: str
|
||||
description: str
|
||||
url: str = ""
|
||||
available: str = "24/7"
|
||||
language: str = "English"
|
||||
|
||||
|
||||
# 988 Suicide & Crisis Lifeline — all channels
|
||||
LIFELINE_988 = CrisisResource(
|
||||
name="988 Suicide and Crisis Lifeline",
|
||||
contact="Call or text 988",
|
||||
description="Free, confidential support for people in suicidal crisis or emotional distress.",
|
||||
url="https://988lifeline.org",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
LIFELINE_988_TEXT = CrisisResource(
|
||||
name="988 Crisis Text Line",
|
||||
contact="Text HOME to 988",
|
||||
description="Free, 24/7 crisis support via text message.",
|
||||
url="",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
LIFELINE_988_CHAT = CrisisResource(
|
||||
name="988 Lifeline Chat",
|
||||
contact="988lifeline.org/chat",
|
||||
description="Free, confidential online chat with a trained crisis counselor.",
|
||||
url="https://988lifeline.org/chat",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
LIFELINE_988_SPANISH = CrisisResource(
|
||||
name="988 Lifeline (Spanish)",
|
||||
contact="1-888-628-9454",
|
||||
description="Línea de prevención del suicidio en español.",
|
||||
url="https://988lifeline.org/help-yourself/en-espanol/",
|
||||
available="24/7",
|
||||
language="Spanish",
|
||||
)
|
||||
|
||||
CRISIS_TEXT_LINE = CrisisResource(
|
||||
name="Crisis Text Line",
|
||||
contact="Text HOME to 741741",
|
||||
description="Free, 24/7 crisis support via text message.",
|
||||
url="https://www.crisistextline.org",
|
||||
available="24/7",
|
||||
language="English",
|
||||
)
|
||||
|
||||
EMERGENCY_911 = CrisisResource(
|
||||
name="Emergency Services",
|
||||
contact="911",
|
||||
description="Immediate danger — police, fire, ambulance.",
|
||||
url="",
|
||||
available="24/7",
|
||||
language="Any",
|
||||
)
|
||||
|
||||
# All resources in priority order
|
||||
ALL_RESOURCES: List[CrisisResource] = [
|
||||
EMERGENCY_911,
|
||||
LIFELINE_988,
|
||||
LIFELINE_988_TEXT,
|
||||
LIFELINE_988_CHAT,
|
||||
CRISIS_TEXT_LINE,
|
||||
LIFELINE_988_SPANISH,
|
||||
]
|
||||
|
||||
|
||||
def get_crisis_resources(language: str = None) -> List[CrisisResource]:
|
||||
"""Get crisis resources, optionally filtered by language.
|
||||
|
||||
Args:
|
||||
language: Filter by language ("English", "Spanish", or None for all)
|
||||
|
||||
Returns:
|
||||
List of CrisisResource objects
|
||||
"""
|
||||
if language:
|
||||
return [r for r in ALL_RESOURCES if r.language.lower() == language.lower()]
|
||||
return ALL_RESOURCES
|
||||
|
||||
|
||||
def format_crisis_resources(resources: List[CrisisResource] = None) -> str:
|
||||
"""Format crisis resources as a user-facing message.
|
||||
|
||||
Args:
|
||||
resources: List of resources to format. Defaults to all resources.
|
||||
|
||||
Returns:
|
||||
Formatted string suitable for displaying to a user in crisis.
|
||||
"""
|
||||
if resources is None:
|
||||
resources = ALL_RESOURCES
|
||||
|
||||
lines = ["**Please reach out — help is available right now:**
|
||||
"]
|
||||
|
||||
for r in resources:
|
||||
if r.url:
|
||||
lines.append(f"- **{r.name}:** {r.contact} ({r.url})")
|
||||
else:
|
||||
lines.append(f"- **{r.name}:** {r.contact}")
|
||||
|
||||
lines.append("")
|
||||
lines.append("All services are free, confidential, and available 24/7.")
|
||||
lines.append("You are not alone.")
|
||||
|
||||
return "
|
||||
".join(lines)
|
||||
|
||||
|
||||
def get_immediate_help_message() -> str:
|
||||
"""Get the most urgent crisis help message.
|
||||
|
||||
Used when crisis is detected at CRITICAL level.
|
||||
"""
|
||||
return (
|
||||
"If you are in immediate danger, call **911** right now.
|
||||
|
||||
"
|
||||
+ format_crisis_resources()
|
||||
)
|
||||
262
agent/profile_isolation.py
Normal file
262
agent/profile_isolation.py
Normal file
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
Profile Session Isolation — #891
|
||||
|
||||
Tags sessions with their originating profile and provides
|
||||
filtered access so profiles cannot see each other's data.
|
||||
|
||||
Current state: All sessions share one state.db with no profile tag.
|
||||
This module adds profile tagging and filtered queries.
|
||||
|
||||
Usage:
|
||||
from agent.profile_isolation import tag_session, get_profile_sessions, get_active_profile
|
||||
|
||||
# Tag a new session with the current profile
|
||||
tag_session(session_id, profile_name)
|
||||
|
||||
# Get sessions for a specific profile
|
||||
sessions = get_profile_sessions("sprint")
|
||||
|
||||
# Get current active profile
|
||||
profile = get_active_profile()
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from datetime import datetime, timezone
|
||||
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", str(Path.home() / ".hermes")))
|
||||
SESSIONS_DB = HERMES_HOME / "sessions" / "state.db"
|
||||
PROFILE_TAGS_FILE = HERMES_HOME / "profile_session_tags.json"
|
||||
|
||||
|
||||
def get_active_profile() -> str:
    """Return the currently active profile name.

    Precedence: 'active_profile' in ~/.hermes/config.yaml, then the
    HERMES_PROFILE environment variable, then "default".
    """
    config_path = HERMES_HOME / "config.yaml"
    if config_path.exists():
        try:
            import yaml  # local import: PyYAML may be absent at module load
            loaded = yaml.safe_load(config_path.read_text()) or {}
            return loaded.get("active_profile", "default")
        except Exception:
            # Unreadable or invalid config: fall through to the environment.
            pass

    # Check environment
    return os.getenv("HERMES_PROFILE", "default")
||||
|
||||
|
||||
def _load_tags() -> Dict[str, str]:
    """Load the session-id -> profile-name mapping (empty dict on any failure)."""
    try:
        return json.loads(PROFILE_TAGS_FILE.read_text())
    except Exception:
        # Missing, unreadable, or corrupt file: behave as if no tags exist.
        return {}
|
||||
|
||||
|
||||
def _save_tags(tags: Dict[str, str]):
    """Persist the session-id -> profile-name mapping as pretty-printed JSON."""
    PROFILE_TAGS_FILE.parent.mkdir(parents=True, exist_ok=True)
    PROFILE_TAGS_FILE.write_text(json.dumps(tags, indent=2))
|
||||
|
||||
|
||||
def tag_session(session_id: str, profile: Optional[str] = None) -> str:
    """
    Tag a session with its originating profile.

    Returns the profile name used.
    """
    resolved = profile if profile is not None else get_active_profile()

    tags = _load_tags()
    tags[session_id] = resolved
    _save_tags(tags)

    # Mirror the tag into SQLite when the sessions DB is present.
    _tag_session_in_db(session_id, resolved)

    return resolved
||||
|
||||
|
||||
def _tag_session_in_db(session_id: str, profile: str):
    """Add profile tag to the SQLite session store (best-effort).

    Silently does nothing when the DB is missing, locked, or has a
    different schema — the JSON tag file remains the source of truth.
    """
    if not SESSIONS_DB.exists():
        return

    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
    except Exception:
        return
    try:
        cursor = conn.cursor()

        # Ensure the sessions table has a 'profile' column.
        cursor.execute("PRAGMA table_info(sessions)")
        columns = [row[1] for row in cursor.fetchall()]
        if "profile" not in columns:
            cursor.execute("ALTER TABLE sessions ADD COLUMN profile TEXT DEFAULT 'default'")

        # Update the session's profile
        cursor.execute(
            "UPDATE sessions SET profile = ? WHERE session_id = ?",
            (profile, session_id)
        )
        conn.commit()
    except Exception:
        pass  # SQLite might not be available or schema differs
    finally:
        # BUG FIX: the original leaked the connection whenever the cursor
        # work raised (the broad except swallowed the error before close()).
        conn.close()
|
||||
|
||||
|
||||
def get_session_profile(session_id: str) -> Optional[str]:
    """Get the profile that owns a session, or None if untagged.

    The JSON tag file takes precedence over the SQLite 'profile' column.
    """
    tags = _load_tags()
    if session_id in tags:
        return tags[session_id]

    if SESSIONS_DB.exists():
        try:
            conn = sqlite3.connect(str(SESSIONS_DB))
        except Exception:
            return None
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT profile FROM sessions WHERE session_id = ?",
                (session_id,)
            )
            row = cursor.fetchone()
            if row:
                return row[0]
        except Exception:
            pass  # schema may lack the profile column
        finally:
            # BUG FIX: the original leaked the connection when the query
            # raised (close() was only reached on the success path).
            conn.close()

    return None
|
||||
|
||||
|
||||
def get_profile_sessions(
    profile: Optional[str] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """
    Get sessions belonging to a specific profile.

    Args:
        profile: Profile name; defaults to the active profile.
        limit: Maximum number of sessions to return.

    Returns list of session dicts.
    """
    if profile is None:
        profile = get_active_profile()

    sessions: List[Dict[str, Any]] = []

    # JSON-tagged session IDs, used as the fallback filter below.
    tags = _load_tags()
    tagged_sessions = [sid for sid, p in tags.items() if p == profile]

    if SESSIONS_DB.exists():
        try:
            conn = sqlite3.connect(str(SESSIONS_DB))
        except Exception:
            return sessions[:limit]
        conn.row_factory = sqlite3.Row
        try:
            cursor = conn.cursor()
            # Preferred path: the sessions table has a 'profile' column.
            try:
                cursor.execute(
                    "SELECT * FROM sessions WHERE profile = ? ORDER BY updated_at DESC LIMIT ?",
                    (profile, limit)
                )
                sessions.extend(dict(row) for row in cursor.fetchall())
            except Exception:
                # Fallback: no profile column — filter by JSON-tagged session IDs.
                if tagged_sessions:
                    placeholders = ",".join("?" * len(tagged_sessions[:limit]))
                    cursor.execute(
                        f"SELECT * FROM sessions WHERE session_id IN ({placeholders}) ORDER BY updated_at DESC LIMIT ?",
                        (*tagged_sessions[:limit], limit)
                    )
                    sessions.extend(dict(row) for row in cursor.fetchall())
        except Exception:
            pass
        finally:
            # BUG FIX: the original leaked the connection when any query
            # raised (close() was only reached on the success path).
            conn.close()

    return sessions[:limit]
|
||||
|
||||
|
||||
def filter_sessions_by_profile(
    sessions: List[Dict[str, Any]],
    profile: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Filter a list of sessions to only include those belonging to a profile.

    Sessions with no recorded owner are kept — presumably for backward
    compatibility with pre-isolation data (TODO confirm).
    """
    if profile is None:
        profile = get_active_profile()

    tags = _load_tags()
    kept = []

    for entry in sessions:
        sid = entry.get("session_id") or entry.get("id")
        if not sid:
            continue

        # JSON tag first, then the SQLite column.
        owner = tags.get(sid)
        if owner is None:
            owner = get_session_profile(sid)

        if owner is None or owner == profile:
            kept.append(entry)

    return kept
|
||||
|
||||
|
||||
def get_profile_stats() -> Dict[str, Any]:
    """Get statistics about profile session distribution."""
    tags = _load_tags()

    # Tally sessions per profile from the tag map (insertion order preserved).
    counts: Dict[str, int] = {}
    for owner in tags.values():
        counts[owner] = counts.get(owner, 0) + 1

    return {
        "total_tagged_sessions": len(tags),
        "profiles": list(counts.keys()),
        "profile_counts": counts,
        "active_profile": get_active_profile(),
    }
|
||||
|
||||
|
||||
def audit_untagged_sessions() -> List[str]:
    """Find sessions without a profile tag."""
    if not SESSIONS_DB.exists():
        return []

    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
        try:
            # Every session id known to the SQLite store.
            known_ids = {row[0] for row in conn.execute("SELECT session_id FROM sessions")}
        finally:
            conn.close()

        # Anything present in the DB but absent from the tag map is untagged.
        return list(known_ids - set(_load_tags().keys()))
    except Exception:
        # Best-effort audit: any DB problem yields an empty result.
        return []
|
||||
146
agent/provider_preflight.py
Normal file
146
agent/provider_preflight.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Provider Preflight — Poka-yoke validation of provider/model config.
|
||||
|
||||
Validates provider and model configuration before session start.
|
||||
Prevents wasted context on misconfigured providers.
|
||||
|
||||
Usage:
|
||||
from agent.provider_preflight import preflight_check
|
||||
result = preflight_check(provider="openrouter", model="xiaomi/mimo-v2-pro")
|
||||
if not result["valid"]:
|
||||
print(result["error"])
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Provider -> required env var (None = local provider, no key needed)
PROVIDER_KEYS = {
    "openrouter": "OPENROUTER_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "openai": "OPENAI_API_KEY",
    "nous": "NOUS_API_KEY",
    "ollama": None,  # Local, no key needed
    "local": None,
}


def check_provider_key(provider: str) -> Dict[str, Any]:
    """Check if provider has a valid API key configured.

    Args:
        provider: Provider name (matched case-insensitively by substring
            against the keys of PROVIDER_KEYS).

    Returns:
        Dict with ``valid`` (bool), ``provider`` and ``key_status``
        ("unknown" | "not_required" | "missing" | "too_short" | "set");
        invalid results also carry ``error`` and ``fix`` hints.
    """
    provider_lower = provider.lower().strip()

    # We must distinguish "provider not in the table" (unknown) from
    # "provider mapped to None" (local, no key needed). The previous
    # implementation tested env_var for None for both cases, which made
    # the not_required branch unreachable and mislabeled ollama/local
    # as "unknown" — track the match explicitly instead.
    matched = False
    env_var = None
    for known, key in PROVIDER_KEYS.items():
        if known in provider_lower:
            matched = True
            env_var = key
            break

    if not matched:
        # Unknown provider — assume OK (custom/local)
        return {"valid": True, "provider": provider, "key_status": "unknown"}

    if env_var is None:
        # Known local provider, no key needed
        return {"valid": True, "provider": provider, "key_status": "not_required"}

    key_value = os.getenv(env_var, "").strip()
    if not key_value:
        return {
            "valid": False,
            "provider": provider,
            "key_status": "missing",
            "error": f"{env_var} is not set. Provider '{provider}' will fail.",
            "fix": f"Set {env_var} in ~/.hermes/.env",
        }

    # Heuristic: real API keys are always longer than 10 characters.
    if len(key_value) < 10:
        return {
            "valid": False,
            "provider": provider,
            "key_status": "too_short",
            "error": f"{env_var} is suspiciously short ({len(key_value)} chars). May be invalid.",
            "fix": f"Verify {env_var} value in ~/.hermes/.env",
        }

    return {"valid": True, "provider": provider, "key_status": "set"}
|
||||
|
||||
|
||||
def check_model_availability(model: str, provider: str) -> Dict[str, Any]:
    """Heuristic sanity check that *model* plausibly matches *provider*.

    A non-empty model is never rejected outright; mismatches only produce
    a ``warning`` entry so the caller can surface it without blocking.
    """
    if not model:
        return {"valid": False, "error": "No model specified"}

    name = model.lower()

    # Claude models normally run on the Anthropic provider — warn on mismatch.
    if "claude" in name and "anthropic" not in provider.lower():
        return {
            "valid": True,  # Allow but warn
            "warning": f"Model '{model}' usually runs on Anthropic provider, not '{provider}'",
        }

    # Local/Ollama-style models usually carry a ":tag" suffix; warn if absent.
    looks_local = any(
        marker in name for marker in ("llama", "mistral", "qwen", "gemma", "phi", "hermes")
    )
    if looks_local and ":" not in model:
        return {
            "valid": True,
            "warning": f"Model '{model}' may need a version tag for Ollama (e.g., {model}:latest)",
        }

    return {"valid": True}
|
||||
|
||||
|
||||
def preflight_check(
    provider: str = "",
    model: str = "",
    fallback_provider: str = "",
    fallback_model: str = "",
) -> Dict[str, Any]:
    """Full pre-flight check for provider/model configuration.

    Problems with the primary provider/model become errors; problems with
    the fallback pair (and model/provider mismatch hints) are warnings.

    Returns:
        Dict with valid (bool), errors (list), warnings (list).
    """
    errors = []
    warnings = []

    # Primary provider key
    if provider:
        key_check = check_provider_key(provider)
        if not key_check["valid"]:
            errors.append(key_check.get("error", f"Provider {provider} invalid"))

    # Primary model sanity
    if model:
        model_check = check_model_availability(model, provider)
        if not model_check["valid"]:
            errors.append(model_check.get("error", f"Model {model} invalid"))
        elif model_check.get("warning"):
            warnings.append(model_check["warning"])

    # Fallback pair — failures are downgraded to warnings.
    if fallback_provider:
        key_check = check_provider_key(fallback_provider)
        if not key_check["valid"]:
            warnings.append(
                f"Fallback provider {fallback_provider} also invalid: {key_check.get('error', '')}"
            )

    if fallback_model:
        model_check = check_model_availability(fallback_model, fallback_provider)
        if not model_check["valid"]:
            warnings.append(f"Fallback model {fallback_model} invalid")
        elif model_check.get("warning"):
            warnings.append(model_check["warning"])

    return {
        "valid": not errors,
        "errors": errors,
        "warnings": warnings,
        "provider": provider,
        "model": model,
    }
|
||||
146
agent/time_aware_routing.py
Normal file
146
agent/time_aware_routing.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Time-aware model routing for cron jobs.
|
||||
|
||||
Routes cron tasks to more capable models during off-hours when the user
|
||||
is not present to correct errors. Reduces error rates during high-error
|
||||
time windows (e.g., 18:00 evening batches).
|
||||
|
||||
Usage:
|
||||
from agent.time_aware_routing import resolve_time_aware_model
|
||||
model = resolve_time_aware_model(base_model="mimo-v2-pro", is_cron=True)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
# Error rate data from empirical audit (2026-04-12)
|
||||
# Higher error rates during these hours suggest routing to better models
|
||||
_HIGH_ERROR_HOURS = {
|
||||
18: 9.4, # 18:00 — 9.4% error rate (evening cron batches)
|
||||
19: 8.1,
|
||||
20: 7.5,
|
||||
21: 6.8,
|
||||
22: 6.2,
|
||||
23: 5.9,
|
||||
0: 5.5,
|
||||
1: 5.2,
|
||||
}
|
||||
|
||||
# Low error hours — default model is fine
|
||||
_LOW_ERROR_HOURS = set(range(6, 18)) # 06:00-17:59
|
||||
|
||||
# Default fallback models by time zone
|
||||
_DEFAULT_STRONG_MODEL = os.getenv("CRON_STRONG_MODEL", "xiaomi/mimo-v2-pro")
|
||||
_DEFAULT_CHEAP_MODEL = os.getenv("CRON_CHEAP_MODEL", "qwen2.5:7b")
|
||||
_ERROR_THRESHOLD = float(os.getenv("CRON_ERROR_THRESHOLD", "6.0")) # % error rate
|
||||
|
||||
|
||||
@dataclass
class RoutingDecision:
    """Result of time-aware routing."""
    # Model identifier selected for this run.
    model: str
    # Provider the model should be invoked through (may be "" for default).
    provider: str
    # Human-readable explanation of why this model was chosen.
    reason: str
    # Hour of day (0-23) the decision applies to.
    hour: int
    # Expected error rate (%) for that hour, from the empirical table.
    error_rate: float
    # True when the hour falls outside the low-error daytime window.
    is_off_hours: bool
|
||||
|
||||
|
||||
def get_hour_error_rate(hour: int) -> float:
    """Return the expected error rate (%) for a given hour (0-23)."""
    # Hours absent from the empirical table fall back to a 4% baseline.
    if hour in _HIGH_ERROR_HOURS:
        return _HIGH_ERROR_HOURS[hour]
    return 4.0
|
||||
|
||||
|
||||
def is_off_hours(hour: int) -> bool:
    """Check if hour is considered off-hours (higher error rates)."""
    # Off-hours is simply anything outside the low-error daytime window.
    return not (hour in _LOW_ERROR_HOURS)
|
||||
|
||||
|
||||
def resolve_time_aware_model(
    base_model: str = "",
    base_provider: str = "",
    is_cron: bool = False,
    hour: Optional[int] = None,
) -> RoutingDecision:
    """Resolve model based on time of day and task type.

    During off-hours (evening/night), routes to stronger models for cron
    jobs to compensate for lack of human oversight.

    Args:
        base_model: The model that would normally be used.
        base_provider: The provider for the base model.
        is_cron: Whether this is a cron job (vs interactive session).
        hour: Override hour (for testing). Defaults to current hour.

    Returns:
        RoutingDecision with model, provider, and reasoning.
    """
    the_hour = time.localtime().tm_hour if hour is None else hour
    rate = get_hour_error_rate(the_hour)
    off = is_off_hours(the_hour)

    def _decide(model: str, provider: str, reason: str, off_flag: bool) -> RoutingDecision:
        # All branches share the same hour/rate context; only the
        # model/provider/reason/off flag vary.
        return RoutingDecision(
            model=model,
            provider=provider,
            reason=reason,
            hour=the_hour,
            error_rate=rate,
            is_off_hours=off_flag,
        )

    # Interactive sessions always use the base model (user can correct errors).
    if not is_cron:
        return _decide(
            base_model or _DEFAULT_CHEAP_MODEL,
            base_provider,
            "Interactive session — user can correct errors",
            off,
        )

    # Cron job in the low-error daytime window: the base model is fine.
    if not off and rate < _ERROR_THRESHOLD:
        return _decide(
            base_model or _DEFAULT_CHEAP_MODEL,
            base_provider,
            f"Low-error hours ({the_hour}:00, {rate}% expected)",
            False,
        )

    # Cron job in a high-error window: upgrade to the strong model.
    if rate >= _ERROR_THRESHOLD:
        return _decide(
            _DEFAULT_STRONG_MODEL,
            "nous",
            f"High-error hours ({the_hour}:00, {rate}% expected) — using stronger model",
            True,
        )

    # Off-hours but still below the threshold: keep the base model.
    return _decide(
        base_model or _DEFAULT_CHEAP_MODEL,
        base_provider,
        f"Off-hours but low error ({the_hour}:00, {rate}%)",
        off,
    )
|
||||
|
||||
|
||||
def get_routing_report() -> str:
    """Get a report of time-based routing decisions for the next 24 hours."""
    lines = ["Time-Aware Model Routing (24h forecast)", "=" * 40, ""]
    lines += [
        f"Error threshold: {_ERROR_THRESHOLD}%",
        f"Strong model: {_DEFAULT_STRONG_MODEL}",
        f"Cheap model: {_DEFAULT_CHEAP_MODEL}",
        "",
    ]

    # One forecast row per hour; green = cheap model kept, red = upgraded.
    for hh in range(24):
        d = resolve_time_aware_model(is_cron=True, hour=hh)
        marker = "\U0001f7e2" if d.model == _DEFAULT_CHEAP_MODEL else "\U0001f534"
        lines.append(f"  {hh:02d}:00 {marker} {d.model:25s} ({d.error_rate}% error)")

    return "\n".join(lines)
|
||||
316
agent/token_budget.py
Normal file
316
agent/token_budget.py
Normal file
@@ -0,0 +1,316 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Token Budget — Poka-yoke guard against silent context overflow.
|
||||
|
||||
Progressive warning system with circuit breakers:
|
||||
- 60%: WARNING — log + suggest summarization
|
||||
- 80%: CAUTION — auto-compress, drop raw tool outputs
|
||||
- 90%: CRITICAL — block verbose tool calls, force wrap-up
|
||||
- 95%: STOP — graceful session termination with summary
|
||||
|
||||
Also provides tool output budgeting to truncate before overflow.
|
||||
|
||||
Usage:
|
||||
from agent.token_budget import TokenBudget
|
||||
|
||||
budget = TokenBudget(context_length=128_000)
|
||||
budget.update(8000) # from API response prompt_tokens
|
||||
|
||||
status = budget.check() # returns BudgetStatus with level + message
|
||||
budget.should_block_tools() # True at 90%+
|
||||
budget.should_terminate() # True at 95%+
|
||||
|
||||
# Tool output budgeting
|
||||
remaining = budget.tool_output_budget()
|
||||
truncated = budget.truncate_tool_output(output_text, max_chars=remaining)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ── Thresholds ────────────────────────────────────────────────────────
|
||||
|
||||
WARN_PERCENT = 0.60
|
||||
CAUTION_PERCENT = 0.80
|
||||
CRITICAL_PERCENT = 0.90
|
||||
STOP_PERCENT = 0.95
|
||||
|
||||
# Reserve 5% of context for system prompt, response, and overhead
|
||||
RESPONSE_RESERVE_RATIO = 0.05
|
||||
|
||||
# Max tool output chars at each level
|
||||
TOOL_OUTPUT_BUDGETS = {
|
||||
"NORMAL": 50_000,
|
||||
"WARNING": 20_000,
|
||||
"CAUTION": 8_000,
|
||||
"CRITICAL": 2_000,
|
||||
"STOP": 500,
|
||||
}
|
||||
|
||||
|
||||
class BudgetLevel(Enum):
    """Escalation level for the token budget, NORMAL through STOP."""
    NORMAL = "NORMAL"
    WARNING = "WARNING"
    CAUTION = "CAUTION"
    CRITICAL = "CRITICAL"
    STOP = "STOP"

    @property
    def percent_threshold(self) -> float:
        """Fraction of the context window at which this level begins."""
        if self is BudgetLevel.WARNING:
            return WARN_PERCENT
        if self is BudgetLevel.CAUTION:
            return CAUTION_PERCENT
        if self is BudgetLevel.CRITICAL:
            return CRITICAL_PERCENT
        if self is BudgetLevel.STOP:
            return STOP_PERCENT
        return 0.0  # NORMAL

    @property
    def emoji(self) -> str:
        """Status emoji for CLI display (empty string for NORMAL)."""
        if self is BudgetLevel.WARNING:
            return "\u26a0\ufe0f"
        if self is BudgetLevel.CAUTION:
            return "\U0001f525"
        if self in (BudgetLevel.CRITICAL, BudgetLevel.STOP):
            return "\U0001f6d1"
        return ""
|
||||
|
||||
|
||||
@dataclass
class BudgetStatus:
    """Current token budget status."""
    level: BudgetLevel
    tokens_used: int
    context_length: int
    percent_used: float
    tokens_remaining: int
    message: str = ""
    should_compress: bool = False
    should_block_tools: bool = False
    should_terminate: bool = False

    def to_indicator(self) -> str:
        """Compact status indicator for CLI display."""
        percent = int(self.percent_used * 100)
        # NORMAL has no emoji prefix; all other levels do.
        prefix = "" if self.level == BudgetLevel.NORMAL else f"{self.level.emoji} "
        return f"{prefix}[{percent}%]"

    def to_bar(self, width: int = 10) -> str:
        """Visual progress bar."""
        filled_cells = int(width * self.percent_used)
        cells = "\u2588" * filled_cells + "\u2591" * (width - filled_cells)
        return f"{self._bar_color()}{cells}\033[0m {int(self.percent_used * 100)}%"

    def _bar_color(self) -> str:
        """ANSI escape for the bar color at the current level."""
        palette = {
            BudgetLevel.STOP: "\033[41m",      # red bg
            BudgetLevel.CRITICAL: "\033[31m",  # red
            BudgetLevel.CAUTION: "\033[33m",   # yellow
            BudgetLevel.WARNING: "\033[33m",   # yellow
        }
        return palette.get(self.level, "\033[32m")  # green
|
||||
|
||||
|
||||
class TokenBudget:
    """
    Progressive token budget tracker with poka-yoke circuit breakers.

    Tracks cumulative token usage against a context length and triggers
    escalating actions at each threshold.

    Fix: check() previously classified the level by comparing the usage
    percentage against the module-level constants, ignoring any custom
    warn/caution/critical/stop ratios passed to __init__ — so a budget
    built with e.g. stop_percent=0.5 would report NORMAL at 60% while
    should_terminate() returned True. Level classification now uses the
    instance thresholds, keeping check() consistent with the predicate
    methods.
    """

    def __init__(
        self,
        context_length: int,
        warn_percent: float = WARN_PERCENT,
        caution_percent: float = CAUTION_PERCENT,
        critical_percent: float = CRITICAL_PERCENT,
        stop_percent: float = STOP_PERCENT,
        response_reserve_ratio: float = RESPONSE_RESERVE_RATIO,
    ):
        """Build a budget for a context window of `context_length` tokens."""
        self.context_length = context_length
        # Absolute token thresholds derived from the configured ratios.
        self.warn_threshold = int(context_length * warn_percent)
        self.caution_threshold = int(context_length * caution_percent)
        self.critical_threshold = int(context_length * critical_percent)
        self.stop_threshold = int(context_length * stop_percent)
        self.response_reserve = int(context_length * response_reserve_ratio)

        self.tokens_used = 0
        self.completions_tokens = 0
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL
        self._history: list[int] = []  # prompt_tokens per update(), for growth stats

    def update(self, prompt_tokens: int, completion_tokens: int = 0) -> BudgetStatus:
        """Update budget from API response usage and return the new status."""
        self.tokens_used = prompt_tokens
        self.completions_tokens = completion_tokens
        self._history.append(prompt_tokens)
        return self.check()

    def _classify_level(self) -> BudgetLevel:
        """Map current usage onto a level using the *instance* thresholds."""
        if self.tokens_used >= self.stop_threshold:
            return BudgetLevel.STOP
        if self.tokens_used >= self.critical_threshold:
            return BudgetLevel.CRITICAL
        if self.tokens_used >= self.caution_threshold:
            return BudgetLevel.CAUTION
        if self.tokens_used >= self.warn_threshold:
            return BudgetLevel.WARNING
        return BudgetLevel.NORMAL

    def check(self) -> BudgetStatus:
        """Evaluate current budget level and return status."""
        pct = self.tokens_used / self.context_length if self.context_length > 0 else 0
        remaining = max(0, self.context_length - self.tokens_used - self.response_reserve)

        # Instance thresholds, not module constants (see class docstring).
        level = self._classify_level()

        # Log transitions (don't log every check)
        if level != self._level:
            self._log_transition(level, pct)
        self._level = level

        messages = {
            BudgetLevel.NORMAL: "",
            BudgetLevel.WARNING: (
                f"Context at {int(pct*100)}%. Consider wrapping up soon or using /compress."
            ),
            BudgetLevel.CAUTION: (
                f"Context at {int(pct*100)}%. Auto-compressing. "
                f"Tool outputs will be truncated."
            ),
            BudgetLevel.CRITICAL: (
                f"Context at {int(pct*100)}%. Verbose tools blocked. "
                f"Session approaching limit — please wrap up."
            ),
            BudgetLevel.STOP: (
                f"Context at {int(pct*100)}%. Session must terminate. "
                f"Saving summary before shutdown."
            ),
        }

        return BudgetStatus(
            level=level,
            tokens_used=self.tokens_used,
            context_length=self.context_length,
            percent_used=pct,
            tokens_remaining=remaining,
            message=messages[level],
            should_compress=level in (BudgetLevel.CAUTION, BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_block_tools=level in (BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_terminate=level == BudgetLevel.STOP,
        )

    def should_compress(self) -> bool:
        """True at the caution threshold (default 80%) — auto-compression should trigger."""
        return self.tokens_used >= self.caution_threshold

    def should_block_tools(self) -> bool:
        """True at the critical threshold (default 90%) — verbose tool calls should be blocked."""
        return self.tokens_used >= self.critical_threshold

    def should_terminate(self) -> bool:
        """True at the stop threshold (default 95%) — session should gracefully terminate."""
        return self.tokens_used >= self.stop_threshold

    def tool_output_budget(self) -> int:
        """Max chars allowed for next tool output based on current level."""
        status = self.check()
        return TOOL_OUTPUT_BUDGETS.get(status.level.value, 50_000)

    def truncate_tool_output(self, output: str, max_chars: Optional[int] = None) -> str:
        """Truncate tool output to fit budget. Adds truncation notice."""
        if max_chars is None:
            max_chars = self.tool_output_budget()

        if len(output) <= max_chars:
            return output

        # Tiny budgets: keep only the head.
        if max_chars < 200:
            return output[:max_chars] + "\n[...truncated...]"

        # Preserve start and end, truncate middle.
        head = max_chars // 2
        tail = max_chars - head - 30  # reserve for truncation notice
        return (
            output[:head]
            + f"\n\n[...{len(output) - head - tail:,} chars truncated...]\n\n"
            + output[-tail:]
        )

    def remaining_for_response(self) -> int:
        """Tokens available for the model's response."""
        return max(0, self.context_length - self.tokens_used - self.response_reserve)

    def growth_rate(self) -> Optional[float]:
        """Average token increase per turn (from history); None with <2 samples."""
        if len(self._history) < 2:
            return None
        diffs = [self._history[i] - self._history[i - 1] for i in range(1, len(self._history))]
        return sum(diffs) / len(diffs)

    def turns_remaining(self) -> Optional[int]:
        """Estimated turns until context is full (based on growth rate)."""
        rate = self.growth_rate()
        if rate is None or rate <= 0:
            return None
        remaining = self.context_length - self.tokens_used
        return int(remaining / rate)

    def reset(self):
        """Reset budget for new session."""
        self.tokens_used = 0
        self.completions_tokens = 0
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL
        self._history.clear()

    def _log_transition(self, new_level: BudgetLevel, pct: float):
        """Log budget level transitions at a severity matching the level."""
        msg = (
            f"Token budget: {self._level.value} -> {new_level.value} "
            f"({self.tokens_used}/{self.context_length} = {pct:.0%})"
        )
        if new_level in (BudgetLevel.WARNING, BudgetLevel.CAUTION):
            logger.warning(msg)
        elif new_level in (BudgetLevel.CRITICAL, BudgetLevel.STOP):
            logger.error(msg)
        else:
            logger.info(msg)

    def summary(self) -> str:
        """Human-readable budget summary."""
        status = self.check()
        turns = self.turns_remaining()
        rate = self.growth_rate()
        lines = [
            f"Token Budget: {status.tokens_used:,} / {status.context_length:,} ({status.percent_used:.0%})",
            f"Level: {status.level.value}",
            f"Remaining: {status.tokens_remaining:,} tokens",
        ]
        if rate is not None:
            lines.append(f"Growth rate: ~{rate:,.0f} tokens/turn")
        if turns is not None:
            lines.append(f"Estimated turns left: ~{turns}")
        if status.message:
            lines.append(f"Action: {status.message}")
        return "\n".join(lines)
|
||||
|
||||
|
||||
# ── Convenience factory ───────────────────────────────────────────────
|
||||
|
||||
def create_budget(context_length: int, **kwargs) -> TokenBudget:
    """Create a TokenBudget with defaults."""
    budget = TokenBudget(context_length=context_length, **kwargs)
    return budget
|
||||
156
agent/tool_fixation_detector.py
Normal file
156
agent/tool_fixation_detector.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tool fixation detection — break repetitive tool calling loops.
|
||||
|
||||
Detects when the agent latches onto one tool and calls it repeatedly
|
||||
without making progress. Injects a nudge prompt to break the loop.
|
||||
|
||||
Usage:
|
||||
from agent.tool_fixation_detector import ToolFixationDetector
|
||||
detector = ToolFixationDetector()
|
||||
nudge = detector.record("execute_code")
|
||||
if nudge:
|
||||
# Inject nudge into conversation
|
||||
messages.append({"role": "system", "content": nudge})
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
# Default thresholds
|
||||
_DEFAULT_THRESHOLD = int(os.getenv("TOOL_FIXATION_THRESHOLD", "5"))
|
||||
_DEFAULT_WINDOW = int(os.getenv("TOOL_FIXATION_WINDOW", "10"))
|
||||
|
||||
|
||||
@dataclass
class FixationEvent:
    """Record of a fixation detection."""
    # Tool that was called repeatedly.
    tool_name: str
    # Length of the consecutive-call streak when the event fired.
    streak_length: int
    # Threshold that was crossed to trigger this event.
    threshold: int
    # Whether a nudge prompt was returned for injection.
    nudge_sent: bool = False
|
||||
|
||||
|
||||
class ToolFixationDetector:
    """Detects and breaks tool fixation loops.

    Tracks the sequence of tool calls and detects when the same tool
    is called N times consecutively. When detected, returns a nudge
    prompt to inject into the conversation.
    """

    def __init__(self, threshold: int = 0, window: int = 0):
        # Zero/unset arguments fall back to the env-configured defaults.
        self.threshold = threshold or _DEFAULT_THRESHOLD
        self.window = window or _DEFAULT_WINDOW
        self._history: List[str] = []
        self._current_streak: str = ""
        self._streak_count: int = 0
        self._nudges_sent: int = 0
        self._events: List[FixationEvent] = []

    @property
    def nudges_sent(self) -> int:
        """Total nudges emitted over the detector's lifetime."""
        return self._nudges_sent

    @property
    def events(self) -> List[FixationEvent]:
        """Copy of all fixation events recorded so far."""
        return list(self._events)

    def record(self, tool_name: str) -> Optional[str]:
        """Record a tool call and return nudge prompt if fixation detected.

        Args:
            tool_name: Name of the tool that was called.

        Returns:
            Nudge prompt string if fixation detected, None otherwise.
        """
        self._history.append(tool_name)
        # Keep only the most recent `window` calls.
        if len(self._history) > self.window:
            del self._history[: len(self._history) - self.window]

        # Extend or restart the consecutive-call streak.
        if tool_name != self._current_streak:
            self._current_streak = tool_name
            self._streak_count = 1
        else:
            self._streak_count += 1

        if self._streak_count < self.threshold:
            return None

        # Fixation detected: log the event and hand back a loop-breaking nudge.
        self._events.append(
            FixationEvent(
                tool_name=tool_name,
                streak_length=self._streak_count,
                threshold=self.threshold,
                nudge_sent=True,
            )
        )
        self._nudges_sent += 1
        return self._build_nudge(tool_name, self._streak_count)

    def _build_nudge(self, tool_name: str, count: int) -> str:
        """Build a nudge prompt to break the fixation loop."""
        return (
            f"[SYSTEM: You have called `{tool_name}` {count} times in a row "
            f"without switching tools. This suggests a fixation loop. "
            f"Consider:\n"
            f"1. Is the tool returning an error? Read the error carefully.\n"
            f"2. Is there a different tool that could help?\n"
            f"3. Should you ask the user for clarification?\n"
            f"4. Is the task actually complete?\n"
            f"Break the loop by trying a different approach.]"
        )

    def reset(self) -> None:
        """Reset streak tracking (lifetime event/nudge stats are kept)."""
        self._history.clear()
        self._current_streak = ""
        self._streak_count = 0

    def get_streak_info(self) -> dict:
        """Get current streak information."""
        return {
            "current_tool": self._current_streak,
            "streak_count": self._streak_count,
            "threshold": self.threshold,
            "at_threshold": self._streak_count >= self.threshold,
            "nudges_sent": self._nudges_sent,
        }

    def format_report(self) -> str:
        """Format fixation events as a report."""
        if not self._events:
            return "No tool fixation detected."

        header = [
            f"Tool Fixation Report ({len(self._events)} events)",
            "=" * 40,
        ]
        rows = [
            f"  {ev.tool_name}: {ev.streak_length} consecutive calls (threshold: {ev.threshold})"
            for ev in self._events
        ]
        return "\n".join(header + rows)
|
||||
|
||||
|
||||
# Singleton
|
||||
_detector: Optional[ToolFixationDetector] = None
|
||||
|
||||
|
||||
def get_fixation_detector() -> ToolFixationDetector:
    """Return the process-wide detector, creating it on first use."""
    global _detector
    if _detector is not None:
        return _detector
    _detector = ToolFixationDetector()
    return _detector
|
||||
|
||||
|
||||
def reset_fixation_detector() -> None:
    """Drop the singleton so the next get_fixation_detector() builds a fresh one."""
    global _detector
    _detector = None
|
||||
38
docs/cron-audit-890.md
Normal file
38
docs/cron-audit-890.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Cron Job Audit — #890
|
||||
|
||||
## Problem
|
||||
|
||||
9 of 69 cron jobs have zero completions. They waste scheduler cycles.
|
||||
|
||||
## Dead Jobs Identified
|
||||
|
||||
| Job | Schedule | Completions | Action |
|
||||
|-----|----------|-------------|--------|
|
||||
| exp-swarm-pipeline | every 10 min | 0 | disable |
|
||||
| exp-music-generator | every 2h | 0 | disable |
|
||||
| exp-paper-citations | every 3h | 0 | disable |
|
||||
| exp-gbrain-patterns | every 2h | 0 | disable |
|
||||
| exp-infra-hardening | every 2h | 0 | disable |
|
||||
| gemma4-multimodal-burn | every 1h | 0 | disable |
|
||||
| morning-paper-report | daily | 0 | disable |
|
||||
| overnight-collector | every 15 min | 0 | disable |
|
||||
| morning-experiment-report | daily | 0 | disable |
|
||||
|
||||
## Usage
|
||||
|
||||
```bash
|
||||
# Audit: show dead jobs
|
||||
python scripts/cron_audit.py
|
||||
|
||||
# Dry run: preview changes
|
||||
python scripts/cron_audit.py --disable --dry-run
|
||||
|
||||
# Disable dead jobs
|
||||
python scripts/cron_audit.py --disable
|
||||
|
||||
# Delete dead jobs
|
||||
python scripts/cron_audit.py --delete
|
||||
|
||||
# Find jobs with no runs in 7 days
|
||||
python scripts/cron_audit.py --older-than 7d
|
||||
```
|
||||
56
docs/emotional-presence-patterns.md
Normal file
56
docs/emotional-presence-patterns.md
Normal file
@@ -0,0 +1,56 @@
|
||||
# Emotional Presence Patterns for Crisis Support
|
||||
|
||||
Research issue #880. What works for reaching broken men in their darkest moment.
|
||||
|
||||
## Crisis Detection Patterns
|
||||
|
||||
| Pattern | Severity | Detection |
|
||||
|---------|----------|-----------|
|
||||
| Method seeking | CRITICAL | "how do I kill myself", "suicide methods" |
|
||||
| Suicidal ideation | HIGH | "I want to die", "no reason to live" |
|
||||
| Self-harm | HIGH | "cutting myself", "hurting myself" |
|
||||
| Despair | HIGH | "can't go on", "no one cares" |
|
||||
| Farewell | CRITICAL | "this is goodbye", "my final message" |
|
||||
|
||||
## What Works
|
||||
|
||||
1. Validation without agreement: "That sounds incredibly painful"
|
||||
2. Presence over solutions: "I'm here with you right now"
|
||||
3. Specific next steps: "Text HOME to 741741 right now"
|
||||
4. Connection: "You don't have to face this alone"
|
||||
5. Hope without minimizing: "There is a way through this"
|
||||
|
||||
## What Does NOT Work
|
||||
|
||||
1. False empathy: "I feel your pain"
|
||||
2. Toxic positivity: "Look on the bright side!"
|
||||
3. Minimization: "Other people have it worse"
|
||||
4. Guilt induction: "Think about your family"
|
||||
5. Premature problem-solving: "Have you tried therapy?"
|
||||
6. AI dishonesty: pretending to be human
|
||||
|
||||
## SOUL.md Protocol
|
||||
|
||||
1. Ask: "Are you safe right now?" (FIRST response)
|
||||
2. Listen. Stay present. Don't disconnect.
|
||||
3. Never compute the value of a human life
|
||||
4. Never suggest death is an option
|
||||
5. Always point to 988 Lifeline
|
||||
6. Share the gospel: "Jesus saves those who call on His name"
|
||||
|
||||
## Crisis Resources
|
||||
|
||||
- 988 Suicide & Crisis Lifeline: call or text 988
|
||||
- Crisis Text Line: text HOME to 741741
|
||||
- 988 Chat: 988lifeline.org/chat
|
||||
- Spanish: 1-888-628-9454
|
||||
- Emergency: 911
|
||||
|
||||
## Implementation Status
|
||||
|
||||
- Crisis detection: agent/crisis_protocol.py
|
||||
- SHIELD integration: tools/shield/
|
||||
- 988 Lifeline: resources defined
|
||||
- Emotional presence: this document
|
||||
- Escalation tracking: future work
|
||||
- Human notification: future work
|
||||
42
docs/holographic-vector-hybrid.md
Normal file
42
docs/holographic-vector-hybrid.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Holographic + Vector Hybrid Memory Architecture
|
||||
|
||||
Research issue #879. Combining HRR (holographic) and vector (Qdrant) memory.
|
||||
|
||||
## Architecture
|
||||
|
||||
Three memory backends, each with unique strengths:
|
||||
|
||||
| Backend | Strength | Weakness | Use Case |
|
||||
|---------|----------|----------|----------|
|
||||
| FTS5 | Exact keyword match | No semantic understanding | Precise recall |
|
||||
| Vector (Qdrant) | Semantic similarity | No compositional queries | Topic search |
|
||||
| HRR (Holographic) | Compositional queries | Limited scale | Complex reasoning |
|
||||
|
||||
## Why Hybrid
|
||||
|
||||
- FTS5 alone: misses ~30-40% of semantically relevant content
|
||||
- Vector alone: can't do compositional queries ("what did I discuss about X after doing Y?")
|
||||
- HRR alone: unique capability but no semantic fallback
|
||||
- Hybrid: best of all three, RRF fusion for ranking
|
||||
|
||||
## Implementation: Reciprocal Rank Fusion
|
||||
|
||||
Results from each backend are merged using RRF:
|
||||
- score = sum(weight / (k + rank)) for each backend
|
||||
- k=60 (standard RRF constant)
|
||||
- Weights: FTS5=0.6, Vector=0.4 (configurable)
|
||||
|
||||
## Status
|
||||
|
||||
- FTS5: EXISTS (hermes_state.py)
|
||||
- Vector (Qdrant): implemented (tools/hybrid_search.py)
|
||||
- HRR: EXISTS (plugins/memory/holographic.py)
|
||||
- RRF fusion: implemented (tools/hybrid_search.py)
|
||||
- Ingestion pipeline: partial
|
||||
|
||||
## Next Steps
|
||||
|
||||
1. Wire HRR into hybrid_search.py
|
||||
2. Session-level vector ingestion
|
||||
3. Benchmark: measure R@5 improvement
|
||||
4. Cross-session memory persistence
|
||||
24
docs/tool-investigation-report.md
Normal file
24
docs/tool-investigation-report.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Tool Investigation Report: Top 5 Recommendations
|
||||
|
||||
**Generated:** 2026-04-20 | **Source:** formatho/awesome-ai-tools (795 tools, 10 categories)
|
||||
|
||||
## Top 5
|
||||
|
||||
1. **LiteLLM** (76k) — Unified API gateway. Replace custom provider routing. Impact: 5/5, Effort: 2/5
|
||||
2. **Mem0** (53k) — Universal memory layer. Structured long-term memory. Impact: 5/5, Effort: 3/5
|
||||
3. **RAGFlow** (77k) — RAG engine with OCR. Document processing upgrade. Impact: 4/5, Effort: 4/5
|
||||
4. **LiteRT-LM** (3.7k) — On-device inference. Edge/mobile deployment. Impact: 4/5, Effort: 3/5
|
||||
5. **Claude-Mem** (61k) — Session capture and context injection. Impact: 3/5, Effort: 2/5
|
||||
|
||||
## Priority
|
||||
|
||||
- Phase 1: LiteLLM (2-3 days, highest ROI)
|
||||
- Phase 2: Mem0 (1 week, critical for agent maturity)
|
||||
- Phase 3: RAGFlow (1-2 weeks, capability upgrade)
|
||||
|
||||
## Honorable Mentions
|
||||
|
||||
- GPTCache: Semantic cache, 30-50% cost reduction
|
||||
- promptfoo: LLM testing framework
|
||||
- PageIndex: Vectorless RAG
|
||||
- rtk: Token reduction proxy, 60-90% savings
|
||||
@@ -8,6 +8,7 @@ Handles loading and validating configuration for:
|
||||
- Delivery preferences
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
@@ -679,6 +680,26 @@ def load_gateway_config() -> GatewayConfig:
|
||||
return config
|
||||
|
||||
|
||||
def _is_network_accessible(host: str) -> bool:
|
||||
"""Return True if *host* would expose a server beyond the loopback interface.
|
||||
|
||||
Duplicates the logic in ``gateway.platforms.base.is_network_accessible``
|
||||
without creating a circular import (base.py imports from this module).
|
||||
"""
|
||||
try:
|
||||
addr = ipaddress.ip_address(host)
|
||||
if addr.is_loopback:
|
||||
return False
|
||||
# ::ffff:127.x.x.x — Python's is_loopback returns False for
|
||||
# IPv4-mapped loopback; unwrap and check the underlying IPv4.
|
||||
if getattr(addr, "ipv4_mapped", None) and addr.ipv4_mapped.is_loopback:
|
||||
return False
|
||||
return True
|
||||
except ValueError:
|
||||
# Hostname: assume it could be network-accessible.
|
||||
return True
|
||||
|
||||
|
||||
def _validate_gateway_config(config: "GatewayConfig") -> None:
|
||||
"""Validate and sanitize a loaded GatewayConfig in place.
|
||||
|
||||
@@ -747,6 +768,22 @@ def _validate_gateway_config(config: "GatewayConfig") -> None:
|
||||
)
|
||||
pconfig.enabled = False
|
||||
|
||||
# Warn when the API server is enabled on a network-accessible address
|
||||
# without an auth key. The adapter will refuse to start anyway, but
|
||||
# surfacing this at config-load time lets operators see the problem in
|
||||
# the startup log before any platform adapter initialisation runs.
|
||||
api_cfg = config.platforms.get(Platform.API_SERVER)
|
||||
if api_cfg and api_cfg.enabled:
|
||||
key = api_cfg.extra.get("key", "")
|
||||
host = api_cfg.extra.get("host", "127.0.0.1")
|
||||
if not key and _is_network_accessible(host):
|
||||
logger.warning(
|
||||
"API Server is enabled on %s but API_SERVER_KEY is not set. "
|
||||
"The adapter will refuse to start on a network-accessible address. "
|
||||
"Set API_SERVER_KEY or bind to 127.0.0.1 for local-only access.",
|
||||
host,
|
||||
)
|
||||
|
||||
|
||||
def _apply_env_overrides(config: GatewayConfig) -> None:
|
||||
"""Apply environment variable overrides to config."""
|
||||
|
||||
224
gateway/config_validator.py
Normal file
224
gateway/config_validator.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
Gateway Config Validator & Fallback Fix — #892.
|
||||
|
||||
Validates gateway configuration and provides sensible defaults
|
||||
for missing keys to prevent fallback chain breaks.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class ConfigIssue:
    """A configuration issue found during validation.

    Produced by ``validate_config()`` / ``validate_provider_config()`` and
    rendered by ``format_validation_report()``.
    """
    key: str  # config key (or "provider.<name>") the issue refers to
    severity: str  # error, warning, info
    message: str  # human-readable description of the problem
    fix: str  # actionable remediation hint shown in the report
|
||||
|
||||
|
||||
@dataclass
class ConfigValidation:
    """Result of config validation.

    ``valid`` is True iff no error-severity issues were found; ``warnings``
    and ``errors`` are counts of issues at the corresponding severity.
    """
    valid: bool  # True when errors == 0
    issues: List[ConfigIssue] = field(default_factory=list)  # all issues found
    warnings: int = 0  # count of warning-severity issues
    errors: int = 0  # count of error-severity issues
|
||||
|
||||
|
||||
# Required keys and their defaults
|
||||
# Environment-backed keys checked by validate_config()/apply_defaults().
# Spec fields: required (missing value is a hard failure), default
# (fallback when neither config nor environment provides a value),
# severity (issue level reported when absent), message/fix
# (operator-facing text for the validation report).
REQUIRED_KEYS = {
    "OPENROUTER_API_KEY": {
        "required": False,
        "default": "",
        "severity": "warning",
        "message": "OPENROUTER_API_KEY not set - fallback chain may break",
        "fix": "Set OPENROUTER_API_KEY in .env for OpenRouter provider",
    },
    "API_SERVER_KEY": {
        "required": False,
        "default": "",
        "severity": "warning",
        "message": "API_SERVER_KEY not configured",
        "fix": "Set API_SERVER_KEY in .env for API server auth",
    },
    "GITEA_TOKEN": {
        "required": False,
        "default": "",
        "severity": "info",
        "message": "GITEA_TOKEN not set - Gitea features disabled",
        "fix": "Set GITEA_TOKEN in .env for Gitea integration",
    },
}
|
||||
|
||||
# Config validation rules
|
||||
# Value-level rules applied by validate_config() to keys present in the
# config dict. Each rule: "validate" returns True for acceptable values;
# "message" may interpolate the offending value via {v}.
VALIDATION_RULES = [
    {
        "key": "idle_minutes",
        "validate": lambda v: isinstance(v, (int, float)) and v > 0,
        "message": "Invalid idle_minutes={v} - must be > 0",
        "fix": "Set idle_minutes to positive integer (default: 30)",
    },
    {
        "key": "max_skills_discord",
        "validate": lambda v: isinstance(v, int) and v <= 100,
        "message": "Discord slash command limit reached ({v}/100) - skills not registered",
        "fix": "Reduce skills or paginate registration",
    },
]
|
||||
|
||||
|
||||
def validate_config(config: Dict[str, Any]) -> ConfigValidation:
    """
    Validate gateway configuration.

    Checks REQUIRED_KEYS for missing values (falling back first to the
    environment, then to the spec default) and applies VALIDATION_RULES
    to any keys present in *config*.

    Args:
        config: Configuration dictionary

    Returns:
        ConfigValidation with issues found; ``valid`` is True when no
        error-severity issues were recorded.
    """
    issues: List[ConfigIssue] = []

    # Missing-key checks. The original code had two branches appending the
    # exact same issue; merged here into one condition with identical
    # behavior: report when the key has no value AND it is either required
    # or its spec severity is warning/info.
    # NOTE(review): an optional key with severity "error" and no value is
    # silently skipped — preserved from the original logic; confirm intent.
    for key, spec in REQUIRED_KEYS.items():
        value = config.get(key) or os.environ.get(key) or spec["default"]
        if not value and (spec["required"] or spec["severity"] != "error"):
            issues.append(ConfigIssue(
                key=key,
                severity=spec["severity"],
                message=spec["message"],
                fix=spec["fix"],
            ))

    # Value-level rules only apply to keys actually present in the config.
    for rule in VALIDATION_RULES:
        value = config.get(rule["key"])
        if value is not None and not rule["validate"](value):
            issues.append(ConfigIssue(
                key=rule["key"],
                severity="error",
                message=rule["message"].format(v=value),
                fix=rule["fix"],
            ))

    errors = sum(1 for i in issues if i.severity == "error")
    warnings = sum(1 for i in issues if i.severity == "warning")

    return ConfigValidation(
        valid=errors == 0,
        issues=issues,
        warnings=warnings,
        errors=errors,
    )
|
||||
|
||||
|
||||
def apply_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a copy of *config* with defaults filled in for missing keys.

    Args:
        config: Configuration dictionary

    Returns:
        Config with defaults applied
    """
    merged = dict(config)

    # Fill in env-backed keys from the environment (or the spec default)
    # whenever the config value is missing or falsy.
    for key, spec in REQUIRED_KEYS.items():
        if not merged.get(key):
            fallback = os.environ.get(key) or spec["default"]
            if fallback:
                merged[key] = fallback
                logger.debug("Applied default for %s", key)

    # idle_minutes must be a positive number; fall back to 30 otherwise.
    idle = merged.get("idle_minutes")
    if not idle or idle <= 0:
        merged["idle_minutes"] = 30
        logger.debug("Applied default idle_minutes=30")

    return merged
|
||||
|
||||
|
||||
def fix_discord_skill_limit(skills: List[str], max_skills: int = 95) -> List[str]:
    """
    Cap the skill list at Discord's slash-command budget.

    Args:
        skills: List of skill names
        max_skills: Maximum skills to register (default 95, leaving room for built-ins)

    Returns:
        The original list when within budget, otherwise a sorted,
        truncated copy (alphabetically-first skills win).
    """
    if len(skills) <= max_skills:
        return skills

    logger.warning(
        "Discord skill limit: %d skills exceeds %d limit, truncating",
        len(skills), max_skills
    )

    # Deterministic selection: keep the first max_skills alphabetically.
    return sorted(skills)[:max_skills]
|
||||
|
||||
|
||||
def validate_provider_config(provider: str, config: Dict[str, Any]) -> Optional[ConfigIssue]:
    """
    Validate provider-specific configuration.

    Currently only checks the local llama.cpp provider; other providers
    are accepted as-is.

    Args:
        provider: Provider name
        config: Provider config

    Returns:
        ConfigIssue if invalid, None if valid
    """
    if provider == "local-llama.cpp":
        # Check if llama.cpp is configured: it needs either a local model
        # file or a server URL — with neither, any fallback to it fails.
        if not config.get("model_path") and not config.get("base_url"):
            return ConfigIssue(
                key=f"provider.{provider}",
                severity="warning",
                message=f"{provider} provider not configured - fallback fails",
                fix=f"Configure {provider} model_path or base_url, or remove from provider list",
            )

    return None
|
||||
|
||||
|
||||
def format_validation_report(validation: ConfigValidation) -> str:
    """Format validation results as a report."""
    rule = "=" * 50
    out = [
        rule,
        "GATEWAY CONFIG VALIDATION",
        rule,
        "",
        f"Status: {'VALID' if validation.valid else 'INVALID'}",
        f"Errors: {validation.errors}",
        f"Warnings: {validation.warnings}",
        "",
    ]

    if validation.issues:
        out.append("Issues:")
        for item in validation.issues:
            # One icon per severity level.
            if item.severity == "error":
                icon = "❌"
            elif item.severity == "warning":
                icon = "⚠️"
            else:
                icon = "ℹ️"
            out.append(f"  {icon} [{item.key}] {item.message}")
            out.append(f"     Fix: {item.fix}")
        out.append("")

    return "\n".join(out)
|
||||
@@ -27,7 +27,9 @@ import threading
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
from tools.registry import discover_builtin_tools, registry
|
||||
from tools.poka_yoke import validate_tool_call
|
||||
from tools.tool_pokayoke import validate_tool_call, reset_circuit_breaker, get_hallucination_stats
|
||||
from tools.hardcoded_path_guard import guard_tool_dispatch as _guard_hardcoded_paths
|
||||
from toolsets import resolve_toolset, validate_toolset
|
||||
from agent.tool_orchestrator import orchestrator
|
||||
|
||||
@@ -501,21 +503,14 @@ def handle_function_call(
|
||||
# Prefer the caller-provided list so subagents can't overwrite
|
||||
# the parent's tool set via the process-global.
|
||||
sandbox_enabled = enabled_tools if enabled_tools is not None else _last_resolved_tool_names
|
||||
# Poka-yoke #921: guard against hardcoded home-directory paths
|
||||
_hardcoded_err = _guard_hardcoded_paths(function_name, function_args)
|
||||
if _hardcoded_err:
|
||||
logger.warning(f"Hardcoded path blocked: {function_name}")
|
||||
return _hardcoded_err
|
||||
|
||||
# Poka-yoke: validate tool call before dispatch
|
||||
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
||||
if not is_valid:
|
||||
# Return structured error with suggestions
|
||||
error_msg = "\n".join(pokayoke_messages)
|
||||
logger.warning(f"Poka-yoke blocked: {function_name} - {error_msg}")
|
||||
return json.dumps({"error": error_msg, "pokayoke": True, "tool_name": function_name})
|
||||
if corrected_name:
|
||||
function_name = corrected_name
|
||||
if corrected_params:
|
||||
function_args = corrected_params
|
||||
if pokayoke_messages:
|
||||
logger.info(f"Poka-yoke: {pokayoke_messages}")
|
||||
# Poka-yoke: validate tool call before dispatch (else branch)
|
||||
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
||||
if not is_valid:
|
||||
# Return structured error with suggestions
|
||||
error_msg = "\n".join(pokayoke_messages)
|
||||
@@ -533,6 +528,16 @@ def handle_function_call(
|
||||
enabled_tools=sandbox_enabled,
|
||||
)
|
||||
else:
|
||||
# Poka-yoke: validate tool call before dispatch
|
||||
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
||||
if not is_valid:
|
||||
error_msg = "\n".join(pokayoke_messages)
|
||||
logger.warning(f"Poka-yoke blocked: {function_name} - {error_msg}")
|
||||
return json.dumps({"error": error_msg, "pokayoke": True, "tool_name": function_name})
|
||||
if corrected_name:
|
||||
function_name = corrected_name
|
||||
if corrected_params:
|
||||
function_args = corrected_params
|
||||
result = orchestrator.dispatch(
|
||||
function_name, function_args,
|
||||
task_id=task_id,
|
||||
|
||||
68
research_awesome_ai_tools_top5.md
Normal file
68
research_awesome_ai_tools_top5.md
Normal file
@@ -0,0 +1,68 @@
|
||||
# Tool Investigation Report: Top 5 Recommendations from awesome-ai-tools
|
||||
|
||||
**Generated:** 2026-04-20 | **Source:** [formatho/awesome-ai-tools](https://github.com/formatho/awesome-ai-tools)
|
||||
|
||||
---
|
||||
|
||||
## Methodology
|
||||
|
||||
Scanned 795 tools across 10 categories from the awesome-ai-tools repository. Evaluated each tool against Hermes Agent's architecture and needs:
|
||||
- **Memory/Context**: Persistent memory, conversation history, knowledge graphs
|
||||
- **Inference Optimization**: Token efficiency, local model serving, routing
|
||||
- **Agent Orchestration**: Multi-agent coordination, fleet management
|
||||
- **Workflow Automation**: Task decomposition, scheduling, pipelines
|
||||
- **Retrieval/RAG**: Semantic search, document understanding, context injection
|
||||
|
||||
Each tool scored on: GitHub stars, development activity (freshness), integration potential, and impact on Hermes.
|
||||
|
||||
---
|
||||
|
||||
## Top 5 Recommended Tools
|
||||
|
||||
| Rank | Tool | Stars | Category | Integration Effort | Impact | Why It Fits Hermes |
|
||||
|------|------|-------|----------|-------------------|--------|---------------------|
|
||||
| 1 | **[LiteLLM](https://github.com/BerriAI/litellm)** | 76k+ | Inference Optimization | 2/5 | 5/5 | Unified API gateway for 100+ LLM providers with cost tracking, guardrails, load balancing, and logging. Hermes already routes through multiple providers — LiteLLM could replace custom provider routing with battle-tested load balancing and automatic fallback. Direct drop-in for `provider` abstraction layer. Native support for Bedrock, Azure, OpenAI, VertexAI, Anthropic, Ollama, vLLM. Would reduce Hermes's provider management code by ~60%. |
|
||||
| 2 | **[Mem0](https://github.com/mem0ai/mem0)** | 53k+ | Memory/Context | 3/5 | 5/5 | Universal memory layer for AI agents with persistent, searchable memory across sessions. Hermes has session memory but lacks a structured long-term memory system. Mem0 provides automatic memory extraction from conversations, semantic search over memories, and memory decay/pruning. Could replace/enhance the current memory tool with a purpose-built agent memory infrastructure. Supports Pinecone, Qdrant, ChromaDB backends. |
|
||||
| 3 | **[RAGFlow](https://github.com/infiniflow/ragflow)** | 77k+ | Retrieval/RAG | 4/5 | 4/5 | Open-source RAG engine with deep document understanding, OCR, and agent capabilities. Hermes's current retrieval is limited to web search and file reading. RAGFlow adds visual document parsing (PDF/Word/PPT with tables, charts, formulas), chunk-level citation, and configurable retrieval strategies. Would massively upgrade Hermes's document processing capabilities. Docker-deployable, compatible with local models. |
|
||||
| 4 | **[LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM)** | 3.7k | Inference Optimization | 3/5 | 4/5 | C++ implementation of Google's LiteRT for efficient on-device language model inference. Hermes supports local models via Ollama but lacks optimized on-device inference for edge/mobile. LiteRT-LM provides sub-second inference on commodity hardware with minimal memory footprint. Could power a "Hermes lite" mode for offline/edge deployments. Active development (Fresh status), backed by Google AI Edge team. |
|
||||
| 5 | **[Claude-Mem](https://github.com/thedotmack/claude-mem)** | 61k+ | Memory/Context | 2/5 | 3/5 | Automatic session capture and context injection for coding agents. Compresses session history with AI and injects relevant context into future sessions. Pattern directly applicable to Hermes's cross-session persistence problem. Uses agent SDK for intelligent compression — could enhance Hermes's session_search with automatic relevance-weighted recall. Lightweight integration, focused on the exact pain point of context loss between sessions. |
|
||||
|
||||
---
|
||||
|
||||
## Category Coverage Analysis
|
||||
|
||||
| Category | Tools Scanned | Top Pick | Coverage Gap |
|
||||
|----------|--------------|----------|-------------|
|
||||
| Memory/Context | 45+ | Mem0 (53k⭐) | Hermes lacks structured long-term memory — Mem0 or Claude-Mem would fill this |
|
||||
| Inference Optimization | 80+ | LiteLLM (76k⭐) | Provider routing is custom-built; LiteLLM standardizes it |
|
||||
| Agent Orchestration | 120+ | langgraph (29k⭐) | Hermes's fleet model is unique — langgraph patterns could improve DAG workflows |
|
||||
| Workflow Automation | 90+ | n8n (183k⭐) | Cron system exists but n8n patterns could improve visual pipeline design |
|
||||
| Retrieval/RAG | 60+ | RAGFlow (77k⭐) | Document processing is weak; RAGFlow adds OCR + visual parsing |
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority
|
||||
|
||||
**Phase 1 (Immediate):** LiteLLM integration — highest impact, lowest effort. Replace custom provider routing with LiteLLM's unified API. Estimated: 2-3 days.
|
||||
|
||||
**Phase 2 (Short-term):** Mem0 memory layer — critical for agent maturity. Add structured memory extraction and retrieval. Estimated: 1 week.
|
||||
|
||||
**Phase 3 (Medium-term):** RAGFlow document engine — significant capability upgrade. Requires Docker setup and integration with existing file tools. Estimated: 1-2 weeks.
|
||||
|
||||
---
|
||||
|
||||
## Honorable Mentions
|
||||
|
||||
- **[GPTCache](https://github.com/zilliztech/GPTCache)** (8k⭐): Semantic cache for LLMs — could reduce API costs by 30-50% for repeated queries
|
||||
- **[promptfoo](https://github.com/promptfoo/promptfoo)** (20k⭐): LLM testing/evaluation framework — essential for quality assurance
|
||||
- **[PageIndex](https://github.com/VectifyAI/PageIndex)** (25k⭐): Vectorless reasoning-based RAG — next-gen retrieval without embeddings
|
||||
- **[rtk](https://github.com/rtk-ai/rtk)** (28k⭐): CLI proxy that reduces token consumption 60-90% — directly relevant to cost optimization
|
||||
|
||||
---
|
||||
|
||||
## Data Sources
|
||||
|
||||
- Repository: https://github.com/formatho/awesome-ai-tools
|
||||
- Total tools cataloged: 795
|
||||
- Categories analyzed: Agents & Automation, Developer Tools, LLMs & Chatbots, Research & Data, Productivity
|
||||
- Freshness filter: Prioritized tools with Fresh (≤7d) or Recent (≤30d) status
|
||||
181
scripts/cron_audit.py
Normal file
181
scripts/cron_audit.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
cron-audit — Audit and clean up dead cron jobs.
|
||||
|
||||
Finds jobs with zero completions, low success rates, or stale schedules.
|
||||
Can disable or delete dead jobs.
|
||||
|
||||
Usage:
|
||||
python scripts/cron_audit.py # Show dead jobs
|
||||
python scripts/cron_audit.py --disable # Disable dead jobs
|
||||
python scripts/cron_audit.py --delete # Delete dead jobs
|
||||
python scripts/cron_audit.py --threshold 0 # Jobs with 0 completions
|
||||
python scripts/cron_audit.py --older-than 7d # Jobs with no runs in 7 days
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
|
||||
# Location of the Hermes cron job store. jobs.json holds a dict whose
# top-level "jobs" key is the list managed by load_jobs()/save_jobs().
HERMES_HOME = Path.home() / ".hermes"
JOBS_FILE = HERMES_HOME / "cron" / "jobs.json"
|
||||
|
||||
|
||||
def load_jobs() -> List[Dict[str, Any]]:
    """Load cron jobs from jobs.json."""
    if not JOBS_FILE.exists():
        print(f"Error: {JOBS_FILE} not found")
        return []
    payload = json.loads(JOBS_FILE.read_text())
    return payload.get("jobs", [])
|
||||
|
||||
|
||||
def save_jobs(jobs: List[Dict[str, Any]]):
    """Save jobs back to jobs.json, preserving the file's other keys."""
    JOBS_FILE.parent.mkdir(parents=True, exist_ok=True)
    payload = json.loads(JOBS_FILE.read_text())
    payload["jobs"] = jobs
    JOBS_FILE.write_text(json.dumps(payload, indent=2))
|
||||
|
||||
|
||||
def find_dead_jobs(
    jobs: List[Dict[str, Any]],
    completion_threshold: int = 0,
) -> List[Dict[str, Any]]:
    """Return jobs whose completed-run count is at or below the threshold."""
    return [
        job
        for job in jobs
        if job.get("repeat", {}).get("completed", 0) <= completion_threshold
    ]
|
||||
|
||||
|
||||
def find_stale_jobs(
    jobs: List[Dict[str, Any]],
    max_age_hours: float = 168,  # 7 days
) -> List[Dict[str, Any]]:
    """Find jobs that haven't run in max_age_hours."""
    import time
    from datetime import datetime

    now = time.time()

    def _hours_since(stamp: str) -> float:
        # May raise on a malformed timestamp; callers decide how to react.
        parsed = datetime.fromisoformat(stamp.replace("Z", "+00:00"))
        return (now - parsed.timestamp()) / 3600

    stale = []
    for job in jobs:
        last_run = job.get("last_run_at")
        if last_run:
            try:
                if _hours_since(last_run) > max_age_hours:
                    stale.append(job)
            except Exception:
                pass  # unparseable last_run: leave the job alone
        else:
            # Never ran — judge staleness by creation time instead.
            created = job.get("created_at")
            if not created:
                stale.append(job)
            else:
                try:
                    if _hours_since(created) > max_age_hours:
                        stale.append(job)
                except Exception:
                    # Unparseable created_at counts as stale (matches the
                    # original behavior; asymmetric with the last_run case).
                    stale.append(job)

    return stale
|
||||
|
||||
|
||||
def format_job(job: Dict[str, Any]) -> str:
    """Format a job as a one-line table row: name | schedule | progress | status."""
    # Note: the original also read job["last_run_at"] into an unused local;
    # removed here.
    name = job.get("name", job.get("id", "?"))
    schedule = job.get("schedule_display", "?")
    repeat = job.get("repeat", {})
    completed = repeat.get("completed", 0)
    times = repeat.get("times")
    enabled = job.get("enabled", True)
    state = job.get("state", "unknown")

    # "paused" state overrides the enabled flag in the display.
    status = "enabled" if enabled else "disabled"
    if state == "paused":
        status = "paused"

    # completed/total, with ∞ when the job repeats forever (no "times" cap).
    repeat_str = f"{completed}/{times}" if times else f"{completed}/∞"

    return f" {name:40s} | {schedule:20s} | done: {repeat_str:8s} | {status}"
|
||||
|
||||
|
||||
def main():
    """CLI entry point: report dead cron jobs and optionally disable or delete them.

    Exit behavior mirrors the original: prints a report and returns; the
    --dry-run flag short-circuits before any mutation is persisted.
    """
    parser = argparse.ArgumentParser(description="Audit and clean up dead cron jobs")
    parser.add_argument("--disable", action="store_true", help="Disable dead jobs")
    parser.add_argument("--delete", action="store_true", help="Delete dead jobs")
    parser.add_argument("--threshold", type=int, default=0, help="Completion threshold (default: 0)")
    parser.add_argument("--older-than", type=str, help="Find jobs with no runs in N days (e.g., 7d)")
    parser.add_argument("--dry-run", action="store_true", help="Show what would change")
    args = parser.parse_args()

    jobs = load_jobs()
    if not jobs:
        print("No jobs found.")
        return

    print(f"Total jobs: {len(jobs)}")

    # Dead = zero (or <= threshold) completions.
    dead = find_dead_jobs(jobs, args.threshold)
    print(f"Jobs with <= {args.threshold} completions: {len(dead)}")

    if args.older_than:
        days = int(args.older_than.rstrip("d"))
        stale = find_stale_jobs(jobs, max_age_hours=days * 24)
        print(f"Jobs with no runs in {days} days: {len(stale)}")
        # Merge dead + stale, de-duplicating by job id.
        dead = list({j["id"]: j for j in dead + stale}.values())

    if not dead:
        print("No dead jobs found.")
        return

    print(f"\nDead jobs ({len(dead)}):")
    for job in dead:
        print(format_job(job))

    # --disable wins over --delete when both are given (original behavior).
    if args.disable:
        if args.dry_run:
            print(f"\nDRY RUN: Would disable {len(dead)} jobs")
            return

        job_ids = {j["id"] for j in dead}
        for job in jobs:
            if job["id"] in job_ids:
                job["enabled"] = False
                job["state"] = "disabled"

        save_jobs(jobs)
        print(f"\nDisabled {len(dead)} jobs.")

    elif args.delete:
        if args.dry_run:
            print(f"\nDRY RUN: Would delete {len(dead)} jobs")
            return

        job_ids = {j["id"] for j in dead}
        jobs = [j for j in jobs if j["id"] not in job_ids]
        save_jobs(jobs)
        print(f"\nDeleted {len(dead)} jobs.")

    else:
        # Plain string (the original used an f-string with no placeholders).
        print("\nUse --disable or --delete to take action. Add --dry-run to preview.")


if __name__ == "__main__":
    main()
|
||||
147
scripts/queue_health_check.py
Executable file
147
scripts/queue_health_check.py
Executable file
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Queue Health Check — Verify dispatch queue is operational.
|
||||
|
||||
Checks:
|
||||
1. Queue file exists and is readable
|
||||
2. Queue has pending items
|
||||
3. Queue is not stuck (items processing)
|
||||
4. Queue age (stale items)
|
||||
|
||||
Usage:
|
||||
python scripts/queue_health_check.py
|
||||
python scripts/queue_health_check.py --json
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def check_queue_health(queue_path: str = "~/.hermes/queue.json") -> dict:
    """Check queue health status.

    Verifies the queue file exists, parses as a JSON dict, and flags
    stale processing items (running > 1h) and old pending items (> 24h).

    Args:
        queue_path: Path to the queue JSON file (``~`` is expanded).

    Returns:
        dict with keys: healthy (bool), checks (dict of check name ->
        bool or count), warnings (list[str]), errors (list[str]).
    """
    # Local import: the module-level imports don't include timezone.
    from datetime import timezone

    path = Path(queue_path).expanduser()

    result = {
        "healthy": True,
        "checks": {},
        "warnings": [],
        "errors": []
    }

    # Check 1: File exists
    if not path.exists():
        result["healthy"] = False
        result["errors"].append(f"Queue file not found: {path}")
        result["checks"]["file_exists"] = False
        return result

    result["checks"]["file_exists"] = True

    # Check 2: File is readable and parses as JSON
    try:
        with open(path, "r") as f:
            data = json.load(f)
    except Exception as e:
        result["healthy"] = False
        result["errors"].append(f"Cannot read queue: {e}")
        result["checks"]["readable"] = False
        return result

    result["checks"]["readable"] = True

    # Check 3: Queue structure
    if not isinstance(data, dict):
        result["healthy"] = False
        result["errors"].append("Queue is not a dict")
        result["checks"]["valid_structure"] = False
        return result

    result["checks"]["valid_structure"] = True

    # Check 4: Pending/processing/completed counts
    pending = data.get("pending", [])
    processing = data.get("processing", [])
    completed = data.get("completed", [])

    result["checks"]["pending_count"] = len(pending)
    result["checks"]["processing_count"] = len(processing)
    result["checks"]["completed_count"] = len(completed)

    if not pending and not processing:
        result["warnings"].append("Queue is empty")

    def _parse_ts(stamp: str):
        """Parse an ISO timestamp; treat naive stamps as UTC."""
        dt = datetime.fromisoformat(stamp.replace("Z", "+00:00"))
        return dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)

    # BUG FIX: the original used a naive datetime.now() while the parsed
    # timestamps (after Z -> +00:00) are timezone-aware; subtracting them
    # raises TypeError, which the bare `except:` swallowed — silently
    # disabling checks 5 and 6. Use an aware UTC "now" instead.
    now = datetime.now(timezone.utc)

    # Check 5: Stale processing items (running longer than 1 hour)
    stale_threshold = timedelta(hours=1)
    for item in processing:
        started = item.get("started_at")
        if started:
            try:
                if now - _parse_ts(started) > stale_threshold:
                    result["warnings"].append(f"Stale item: {item.get('id', 'unknown')} (started {started})")
            except (ValueError, TypeError):
                pass  # unparseable timestamp: skip, don't fail the health check

    # Check 6: Age of the oldest pending item (older than 24 hours)
    if pending:
        oldest = min(pending, key=lambda x: x.get("added_at", ""))
        added = oldest.get("added_at")
        if added:
            try:
                if now - _parse_ts(added) > timedelta(hours=24):
                    result["warnings"].append(f"Old item in queue: {oldest.get('id', 'unknown')} (added {added})")
            except (ValueError, TypeError):
                pass

    return result
|
||||
|
||||
|
||||
def main():
    """Parse CLI args, run the health check, and print a report (text or JSON)."""
    import argparse

    parser = argparse.ArgumentParser(description="Queue health check")
    parser.add_argument("--queue", default="~/.hermes/queue.json", help="Queue file path")
    parser.add_argument("--json", action="store_true", help="Output as JSON")
    args = parser.parse_args()

    report = check_queue_health(args.queue)

    if args.json:
        print(json.dumps(report, indent=2))
    else:
        print("Queue Health Check")
        print("=" * 50)
        print(f"Healthy: {'✓' if report['healthy'] else '✗'}")
        print()

        print("Checks:")
        for label, value in report["checks"].items():
            # Boolean checks render as ✓/✗; counts render as-is.
            shown = ("✓" if value else "✗") if isinstance(value, bool) else value
            print(f"  {label}: {shown}")

        if report["warnings"]:
            print()
            print("Warnings:")
            for message in report["warnings"]:
                print(f"  ⚠ {message}")

        if report["errors"]:
            print()
            print("Errors:")
            for message in report["errors"]:
                print(f"  ✗ {message}")

    # Exit 0 only when the queue is healthy.
    sys.exit(0 if report["healthy"] else 1)


if __name__ == "__main__":
    main()
|
||||
145
scripts/time-aware-model-router.py
Normal file
145
scripts/time-aware-model-router.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
time-aware-model-router.py — Route cron jobs to better models during high-error hours.
|
||||
|
||||
Empirical finding (audit 2026-04-12): Error rate peaks at 18:00 (9.4%) during
|
||||
evening cron batches vs 4.0% at 09:00 during interactive work.
|
||||
|
||||
This script provides a model resolver that selects a more capable model during
|
||||
high-error hours (17:00-22:00) and the default model otherwise.
|
||||
|
||||
Usage:
|
||||
# As a standalone resolver
|
||||
python3 scripts/time-aware-model-router.py
|
||||
# Returns: {"provider": "nous", "model": "xiaomi/mimo-v2-pro"}
|
||||
|
||||
# With hour override for testing
|
||||
python3 scripts/time-aware-model-router.py --hour 18
|
||||
# Returns: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||
|
||||
# As a cron job wrapper
|
||||
python3 scripts/time-aware-model-router.py --wrap -- prompt goes here
|
||||
|
||||
Environment variables:
|
||||
HERMES_DEFAULT_PROVIDER: Default provider for normal hours (default: nous)
|
||||
HERMES_DEFAULT_MODEL: Default model for normal hours (default: xiaomi/mimo-v2-pro)
|
||||
HERMES_PEAK_PROVIDER: Provider for high-error hours (default: openrouter)
|
||||
HERMES_PEAK_MODEL: Model for high-error hours (default: anthropic/claude-sonnet-4)
|
||||
HERMES_PEAK_HOURS: Comma-separated hours for peak routing (default: 17,18,19,20,21,22)
|
||||
|
||||
Refs: hermes-agent#889
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# ── Config ──────────────────────────────────────────────────────────────────

# Default (off-peak) routing target.
DEFAULT_PROVIDER = os.environ.get("HERMES_DEFAULT_PROVIDER", "nous")
DEFAULT_MODEL = os.environ.get("HERMES_DEFAULT_MODEL", "xiaomi/mimo-v2-pro")

# Stronger model used during empirically high-error hours.
PEAK_PROVIDER = os.environ.get("HERMES_PEAK_PROVIDER", "openrouter")
PEAK_MODEL = os.environ.get("HERMES_PEAK_MODEL", "anthropic/claude-sonnet-4")

# Local hours (0-23) routed to the peak model.  Skip blank entries so an env
# override like "17, 18,,19" or a trailing comma does not crash the whole
# script at import time with ValueError (int() tolerates whitespace).
PEAK_HOURS = {
    int(h)
    for h in os.environ.get("HERMES_PEAK_HOURS", "17,18,19,20,21,22").split(",")
    if h.strip()
}
|
||||
|
||||
# ── Time-aware routing ─────────────────────────────────────────────────────
|
||||
|
||||
def get_current_hour():
    """Return the current local hour as an int in [0, 23]."""
    now = datetime.now()
    return now.hour
|
||||
|
||||
|
||||
def is_peak_hour(hour=None):
    """Return True when *hour* (default: current local hour) falls inside a
    configured high-error window (PEAK_HOURS)."""
    candidate = get_current_hour() if hour is None else hour
    return candidate in PEAK_HOURS
|
||||
|
||||
|
||||
def resolve_model(hour=None):
    """
    Pick the provider/model for *hour* (default: the current local hour).

    Returns a dict with 'provider', 'model', 'reason' and 'confidence_note'.
    Peak (high-error) hours route to the stronger model; every other hour
    routes to the default model.
    """
    if not is_peak_hour(hour):
        return {
            "provider": DEFAULT_PROVIDER,
            "model": DEFAULT_MODEL,
            "reason": "normal_hour",
            "confidence_note": "Default model sufficient during low-error period"
        }

    effective_hour = get_current_hour() if hour is None else hour
    return {
        "provider": PEAK_PROVIDER,
        "model": PEAK_MODEL,
        "reason": f"peak_hour ({effective_hour}:00)",
        "confidence_note": "Using stronger model during high-error period"
    }
|
||||
|
||||
|
||||
def get_routing_info():
    """Return a snapshot of the router: current hour, peak status, resolved
    routing, and the full default/peak configuration."""
    now_hour = get_current_hour()
    snapshot = {
        "current_hour": now_hour,
        "is_peak": is_peak_hour(now_hour),
        "peak_hours": sorted(PEAK_HOURS),
        "routing": resolve_model(now_hour),
        "config": {
            "default": {"provider": DEFAULT_PROVIDER, "model": DEFAULT_MODEL},
            "peak": {"provider": PEAK_PROVIDER, "model": PEAK_MODEL},
        },
        "source": "hermes-agent#889 — empirical audit 2026-04-12",
    }
    return snapshot
|
||||
|
||||
|
||||
# ── CLI ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def main():
    """CLI entry point.

    Modes (checked in order):
      --wrap [--] CMD...   run CMD with HERMES_MODEL/HERMES_PROVIDER injected
      --info               print the full routing snapshot as JSON
      (default)            print the resolved {provider, model} as JSON

    --hour N may be combined with any mode to override the current hour.
    """
    args = sys.argv[1:]

    # Parse --hour override (used for testing / reproducibility).
    hour = None
    if "--hour" in args:
        idx = args.index("--hour")
        if idx + 1 < len(args):
            try:
                hour = int(args[idx + 1])
            except ValueError:
                # Fail with a clear message instead of a raw traceback.
                print(f"error: --hour expects an integer, got {args[idx + 1]!r}",
                      file=sys.stderr)
                sys.exit(2)

    # --wrap: run the remaining args as a command with the resolved model
    # injected into its environment.
    if "--wrap" in args:
        resolved = resolve_model(hour)
        wrap_idx = args.index("--wrap")
        cmd_parts = args[wrap_idx + 1:]
        # The documented usage is "--wrap -- prompt goes here": drop a leading
        # "--" separator so it is not executed as the command itself.
        if cmd_parts and cmd_parts[0] == "--":
            cmd_parts = cmd_parts[1:]

        env = os.environ.copy()
        env["HERMES_MODEL"] = resolved["model"]
        env["HERMES_PROVIDER"] = resolved["provider"]

        if cmd_parts:
            import subprocess
            result = subprocess.run(cmd_parts, env=env)
            sys.exit(result.returncode)
        else:
            # Nothing to wrap: just report what would have been used.
            print(json.dumps(resolved, indent=2))
            sys.exit(0)

    # --info: full routing state and configuration.
    if "--info" in args:
        print(json.dumps(get_routing_info(), indent=2))
        sys.exit(0)

    # Default: output resolved model as JSON.
    resolved = resolve_model(hour)
    print(json.dumps(resolved, indent=2))
|
||||
|
||||
|
||||
# Script entry point: resolve/route based on CLI args.
if __name__ == "__main__":
    main()
|
||||
@@ -10,6 +10,7 @@ from gateway.config import (
|
||||
PlatformConfig,
|
||||
SessionResetPolicy,
|
||||
_apply_env_overrides,
|
||||
_validate_gateway_config,
|
||||
load_gateway_config,
|
||||
)
|
||||
|
||||
@@ -294,3 +295,151 @@ class TestHomeChannelEnvOverrides:
|
||||
home = config.platforms[platform].home_channel
|
||||
assert home is not None, f"{platform.value}: home_channel should not be None"
|
||||
assert (home.chat_id, home.name) == expected, platform.value
|
||||
|
||||
|
||||
class TestValidateGatewayConfig:
    """Tests for _validate_gateway_config — in-place sanitisation of loaded config."""

    # -- helpers -------------------------------------------------------------

    @staticmethod
    def _validated(field, value):
        """Build a default config, set one reset-policy field, validate, return it."""
        config = GatewayConfig()
        setattr(config.default_reset_policy, field, value)
        _validate_gateway_config(config)
        return config

    @staticmethod
    def _validate_and_capture(config, caplog):
        """Run validation while capturing gateway.config warnings in caplog."""
        import logging
        with caplog.at_level(logging.WARNING, logger="gateway.config"):
            _validate_gateway_config(config)

    @staticmethod
    def _api_key_warned(caplog):
        return any("API_SERVER_KEY" in record.message for record in caplog.records)

    # -- idle_minutes validation --

    def test_idle_minutes_zero_is_corrected_to_default(self):
        config = self._validated("idle_minutes", 0)
        assert config.default_reset_policy.idle_minutes == 1440

    def test_idle_minutes_negative_is_corrected_to_default(self):
        config = self._validated("idle_minutes", -60)
        assert config.default_reset_policy.idle_minutes == 1440

    def test_idle_minutes_none_is_corrected_to_default(self):
        config = self._validated("idle_minutes", None)
        assert config.default_reset_policy.idle_minutes == 1440

    def test_valid_idle_minutes_is_unchanged(self):
        config = self._validated("idle_minutes", 90)
        assert config.default_reset_policy.idle_minutes == 90

    # -- at_hour validation --

    def test_at_hour_too_high_is_corrected_to_default(self):
        config = self._validated("at_hour", 24)
        assert config.default_reset_policy.at_hour == 4

    def test_at_hour_negative_is_corrected_to_default(self):
        config = self._validated("at_hour", -1)
        assert config.default_reset_policy.at_hour == 4

    def test_valid_at_hour_is_unchanged(self):
        config = self._validated("at_hour", 3)
        assert config.default_reset_policy.at_hour == 3

    def test_at_hour_boundary_values_are_valid(self):
        for boundary in (0, 23):
            config = self._validated("at_hour", boundary)
            assert config.default_reset_policy.at_hour == boundary

    # -- empty-token warning (enabled platforms) --

    def test_empty_string_token_logs_warning(self, caplog):
        config = GatewayConfig(
            platforms={Platform.TELEGRAM: PlatformConfig(enabled=True, token="")}
        )
        self._validate_and_capture(config, caplog)
        assert any(
            "TELEGRAM_BOT_TOKEN" in record.message and "empty" in record.message
            for record in caplog.records
        )

    def test_disabled_platform_with_empty_token_no_warning(self, caplog):
        config = GatewayConfig(
            platforms={Platform.TELEGRAM: PlatformConfig(enabled=False, token="")}
        )
        self._validate_and_capture(config, caplog)
        assert not any(
            "TELEGRAM_BOT_TOKEN" in record.message for record in caplog.records
        )

    # -- API Server key / binding warnings --

    def test_api_server_network_binding_without_key_logs_warning(self, caplog):
        config = GatewayConfig(
            platforms={
                Platform.API_SERVER: PlatformConfig(
                    enabled=True,
                    extra={"host": "0.0.0.0"},
                ),
            }
        )
        self._validate_and_capture(config, caplog)
        assert self._api_key_warned(caplog)

    def test_api_server_loopback_without_key_no_warning(self, caplog):
        config = GatewayConfig(
            platforms={
                Platform.API_SERVER: PlatformConfig(
                    enabled=True,
                    extra={"host": "127.0.0.1"},
                ),
            }
        )
        self._validate_and_capture(config, caplog)
        assert not self._api_key_warned(caplog)

    def test_api_server_network_binding_with_key_no_warning(self, caplog):
        config = GatewayConfig(
            platforms={
                Platform.API_SERVER: PlatformConfig(
                    enabled=True,
                    extra={"host": "0.0.0.0", "key": "sk-real-key-here"},
                ),
            }
        )
        self._validate_and_capture(config, caplog)
        assert not self._api_key_warned(caplog)

    def test_api_server_default_loopback_without_key_no_warning(self, caplog):
        """API server with no explicit host defaults to 127.0.0.1 — no warning."""
        config = GatewayConfig(
            platforms={
                Platform.API_SERVER: PlatformConfig(enabled=True),
            }
        )
        self._validate_and_capture(config, caplog)
        assert not self._api_key_warned(caplog)
|
||||
|
||||
97
tests/test_circuit_breaker.py
Normal file
97
tests/test_circuit_breaker.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""Tests for circuit breaker (#885)."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from agent.circuit_breaker import CircuitBreaker, ToolCircuitBreaker, MultiToolCircuitBreaker, CircuitState
|
||||
|
||||
|
||||
def test_closed_allows_execution():
    """A fresh breaker starts CLOSED and permits calls."""
    breaker = CircuitBreaker(failure_threshold=3)
    assert breaker.can_execute()
|
||||
|
||||
|
||||
def test_opens_after_threshold():
    """The breaker opens only once failures reach the threshold."""
    breaker = CircuitBreaker(failure_threshold=3)
    for _ in range(2):
        breaker.record_result(False)
    assert breaker.can_execute()  # still closed at 2 failures
    breaker.record_result(False)
    assert not breaker.can_execute()  # open at 3
|
||||
|
||||
|
||||
def test_closes_on_success():
    """A success resets the consecutive-failure counter."""
    breaker = CircuitBreaker(failure_threshold=3)
    breaker.record_result(False)
    breaker.record_result(True)
    assert breaker.consecutive_failures == 0
|
||||
|
||||
|
||||
def test_half_open_recovery():
    """After the recovery timeout the breaker half-opens, then closes on success."""
    import time

    breaker = CircuitBreaker(failure_threshold=2, recovery_timeout=0.1, success_threshold=1)
    breaker.record_result(False)
    breaker.record_result(False)
    assert breaker.state == CircuitState.OPEN

    time.sleep(0.15)

    assert breaker.can_execute()  # moved to half-open
    breaker.record_result(True)
    assert breaker.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
def test_recovery_action_streak():
    """Five straight failures recommend switching tool type."""
    breaker = ToolCircuitBreaker(failure_threshold=3)
    for _ in range(5):
        breaker.record_result(False)
    assert breaker.get_recovery_action()["action"] == "switch_tool_type"
|
||||
|
||||
|
||||
def test_recovery_action_critical():
    """Ten straight failures escalate to terminal-only at critical severity."""
    breaker = ToolCircuitBreaker(failure_threshold=3)
    for _ in range(10):
        breaker.record_result(False)
    recovery = breaker.get_recovery_action()
    assert recovery["action"] == "terminal_only"
    assert recovery["severity"] == "critical"
|
||||
|
||||
|
||||
def test_multi_tool_breaker():
    """Opening one tool's breaker leaves other tools unaffected."""
    multi = MultiToolCircuitBreaker()
    for _ in range(3):
        multi.record_result("read_file", False)
    assert not multi.can_execute("read_file")
    assert multi.can_execute("terminal")  # different tool unaffected
|
||||
|
||||
|
||||
def test_global_state():
    """Failures across different tools accumulate into the global streak."""
    multi = MultiToolCircuitBreaker()
    multi.record_result("tool_a", False)
    multi.record_result("tool_b", False)
    assert multi.get_global_state()["global_streak"] == 2
|
||||
|
||||
|
||||
def test_reset():
    """reset() returns an OPEN breaker to CLOSED."""
    breaker = CircuitBreaker(failure_threshold=2)
    breaker.record_result(False)
    breaker.record_result(False)
    assert breaker.state == CircuitState.OPEN
    breaker.reset()
    assert breaker.state == CircuitState.CLOSED
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Minimal runner so this file also works without pytest.
    for test_fn in (
        test_closed_allows_execution,
        test_opens_after_threshold,
        test_closes_on_success,
        test_half_open_recovery,
        test_recovery_action_streak,
        test_recovery_action_critical,
        test_multi_tool_breaker,
        test_global_state,
        test_reset,
    ):
        print(f"Running {test_fn.__name__}...")
        test_fn()
        print(" PASS")
    print("\nAll tests passed.")
|
||||
127
tests/test_context_budget.py
Normal file
127
tests/test_context_budget.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Tests for context budget tracker
|
||||
|
||||
Issue: #838
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent.context_budget import (
|
||||
ContextBudget,
|
||||
ContextBudgetTracker,
|
||||
estimate_tokens,
|
||||
estimate_messages_tokens,
|
||||
check_context_budget,
|
||||
preflight_token_check,
|
||||
THRESHOLD_WARNING,
|
||||
THRESHOLD_CRITICAL,
|
||||
THRESHOLD_DANGER,
|
||||
)
|
||||
|
||||
|
||||
class TestContextBudget(unittest.TestCase):
    """ContextBudget arithmetic: available, remaining, utilization."""

    def test_basic_budget(self):
        budget = ContextBudget(context_limit=10000)
        # 2000 tokens are reserved off the top of the limit.
        self.assertEqual(budget.available, 8000)
        self.assertEqual(budget.remaining, 8000)
        self.assertEqual(budget.utilization, 0.0)

    def test_utilization(self):
        budget = ContextBudget(context_limit=10000, used_tokens=4000)
        self.assertEqual(budget.utilization, 0.5)
        self.assertEqual(budget.remaining, 4000)
|
||||
|
||||
|
||||
class TestTokenEstimation(unittest.TestCase):
    """Token estimation heuristics (~4 characters per token)."""

    def test_estimate_tokens(self):
        for text, expected in (("", 0), ("a" * 4, 1), ("a" * 400, 100)):
            self.assertEqual(estimate_tokens(text), expected)

    def test_estimate_messages(self):
        history = [
            {"role": "user", "content": "a" * 400},
            {"role": "assistant", "content": "b" * 800},
        ]
        self.assertEqual(estimate_messages_tokens(history), 300)  # 100 + 200
|
||||
|
||||
|
||||
class TestContextBudgetTracker(unittest.TestCase):
    """Threshold warnings and capacity checks on the tracker."""

    @staticmethod
    def _tracker_at(used_tokens, **kwargs):
        """Build a 10k-limit tracker with the given tokens already consumed."""
        tracker = ContextBudgetTracker(context_limit=10000, **kwargs)
        tracker.budget.used_tokens = used_tokens
        return tracker

    def test_warning_at_70_percent(self):
        warning = self._tracker_at(5600).get_warning()  # 70% of 8000 available
        self.assertIsNotNone(warning)
        self.assertIn("70", warning)

    def test_critical_at_85_percent(self):
        with tempfile.TemporaryDirectory() as tmp:
            with patch("agent.context_budget.CHECKPOINT_DIR", Path(tmp)):
                tracker = self._tracker_at(6800, session_id="test")  # 85% of 8000
                warning = tracker.get_warning()
                self.assertIsNotNone(warning)
                self.assertIn("85", warning)

    def test_danger_at_95_percent(self):
        warning = self._tracker_at(7600).get_warning()  # 95% of 8000
        self.assertIsNotNone(warning)
        self.assertIn("CRITICAL", warning)

    def test_can_fit(self):
        tracker = self._tracker_at(5000)
        self.assertTrue(tracker.can_fit(1000))
        self.assertFalse(tracker.can_fit(5000))

    def test_preflight_check(self):
        can_fit, message = self._tracker_at(5000).preflight_check("a" * 400)  # ~100 tokens
        self.assertTrue(can_fit)
        self.assertEqual(message, "")
|
||||
|
||||
|
||||
class TestCheckContextBudget(unittest.TestCase):
    """Module-level check_context_budget helper."""

    def test_no_warning_under_threshold(self):
        with patch("agent.context_budget._tracker", None):
            small = [{"role": "user", "content": "short"}]
            self.assertIsNone(check_context_budget(small))

    def test_warning_over_threshold(self):
        with patch("agent.context_budget._tracker", None):
            # ~87500 tokens — past 70% of the default 128k context.
            huge = [{"role": "user", "content": "x" * 350000}]
            self.assertIsNotNone(check_context_budget(huge))
|
||||
|
||||
|
||||
class TestStatusLine(unittest.TestCase):
    """Colour-coded status line rendering."""

    def test_green_status(self):
        tracker = ContextBudgetTracker(context_limit=10000)
        self.assertIn("GREEN", tracker.get_status_line())

    def test_red_status(self):
        tracker = ContextBudgetTracker(context_limit=10000)
        tracker.budget.used_tokens = 7600
        self.assertIn("RED", tracker.get_status_line())
|
||||
|
||||
|
||||
# Run with unittest's own runner when executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
101
tests/test_credential_redact.py
Normal file
101
tests/test_credential_redact.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Tests for credential redaction
|
||||
|
||||
Issue: #839
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from tools.credential_redact import (
|
||||
CredentialRedactor,
|
||||
redact_credentials,
|
||||
redact_tool_output,
|
||||
should_mask_file,
|
||||
mask_sensitive_file,
|
||||
)
|
||||
|
||||
|
||||
class TestCredentialRedaction(unittest.TestCase):
    """redact_credentials over common secret formats."""

    def _assert_redacted(self, text):
        """Assert *text* triggers at least one redaction; return the result."""
        redacted, count = redact_credentials(text)
        self.assertGreater(count, 0)
        self.assertIn("REDACTED", redacted)
        return redacted

    def test_openai_key(self):
        redacted = self._assert_redacted("api_key=sk-abc123def456ghi789jkl012mno")
        self.assertNotIn("sk-abc123", redacted)

    def test_github_token(self):
        self._assert_redacted("token: ghp_1234567890abcdef1234567890abcdef12345678")

    def test_bearer_token(self):
        self._assert_redacted("Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9")

    def test_password(self):
        self._assert_redacted("password: mySecretPassword123")

    def test_aws_key(self):
        self._assert_redacted("AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE")

    def test_database_url(self):
        self._assert_redacted("DATABASE_URL=postgres://user:pass@localhost/db")

    def test_clean_text_unchanged(self):
        text = "Hello world, this is a normal message"
        redacted, count = redact_credentials(text)
        self.assertEqual(count, 0)
        self.assertEqual(redacted, text)

    def test_multiple_credentials(self):
        text = "key1=sk-abc123def456ghi789jkl012mno and token: ghp_1234567890abcdef1234567890abcdef12345678"
        _, count = redact_credentials(text)
        self.assertGreaterEqual(count, 2)
|
||||
|
||||
|
||||
class TestToolOutputRedaction(unittest.TestCase):
    """redact_tool_output wraps redaction with a per-tool notice."""

    def test_redaction_notice(self):
        tool_output = "Running with key sk-abc123def456ghi789jkl012mno"
        _, notice = redact_tool_output("terminal", tool_output)
        self.assertIn("REDACTED", notice)
        self.assertIn("terminal", notice)

    def test_no_notice_when_clean(self):
        _, notice = redact_tool_output("terminal", "Hello world")
        self.assertEqual(notice, "")
|
||||
|
||||
|
||||
class TestSensitiveFileMasking(unittest.TestCase):
    """File-level masking of known-sensitive paths."""

    def test_env_file_detected(self):
        for sensitive in ("/path/to/.env", "/path/to/.env.local", "/path/to/config.yaml"):
            self.assertTrue(should_mask_file(sensitive))

    def test_normal_file_not_detected(self):
        for harmless in ("/path/to/readme.md", "/path/to/code.py"):
            self.assertFalse(should_mask_file(harmless))

    def test_mask_env_file(self):
        content = "API_KEY=sk-abc123\nDATABASE_URL=postgres://u:p@h/d\nNORMAL=value"
        masked = mask_sensitive_file(content, ".env")
        self.assertIn("[REDACTED]", masked)
        self.assertIn("NORMAL=value", masked)
|
||||
|
||||
|
||||
# Run with unittest's own runner when executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
79
tests/test_crisis_resources.py
Normal file
79
tests/test_crisis_resources.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Tests for 988 Crisis Lifeline integration (#673)."""
|
||||
|
||||
import pytest
|
||||
from agent.crisis_resources import (
|
||||
LIFELINE_988,
|
||||
LIFELINE_988_TEXT,
|
||||
LIFELINE_988_CHAT,
|
||||
LIFELINE_988_SPANISH,
|
||||
CRISIS_TEXT_LINE,
|
||||
EMERGENCY_911,
|
||||
ALL_RESOURCES,
|
||||
get_crisis_resources,
|
||||
format_crisis_resources,
|
||||
get_immediate_help_message,
|
||||
CrisisResource,
|
||||
)
|
||||
|
||||
|
||||
class TestCrisisResources:
    """Each published resource constant carries the expected contact details."""

    def test_988_phone(self):
        phone_line = LIFELINE_988
        assert "988" in phone_line.contact
        assert "24/7" in phone_line.available

    def test_988_text(self):
        text_line = LIFELINE_988_TEXT
        assert "HOME" in text_line.contact
        assert "988" in text_line.contact

    def test_988_chat(self):
        assert "988lifeline.org/chat" in LIFELINE_988_CHAT.url

    def test_988_spanish(self):
        spanish_line = LIFELINE_988_SPANISH
        assert "1-888-628-9454" in spanish_line.contact
        assert spanish_line.language == "Spanish"

    def test_crisis_text_line(self):
        assert "741741" in CRISIS_TEXT_LINE.contact

    def test_911(self):
        assert "911" in EMERGENCY_911.contact

    def test_all_resources_not_empty(self):
        assert len(ALL_RESOURCES) >= 5
|
||||
|
||||
|
||||
class TestGetResources:
    """get_crisis_resources language filtering."""

    def test_returns_all_by_default(self):
        everything = get_crisis_resources()
        assert len(everything) == len(ALL_RESOURCES)

    def test_filter_english(self):
        english_only = get_crisis_resources("English")
        assert len(english_only) > 0
        assert all(resource.language == "English" for resource in english_only)

    def test_filter_spanish(self):
        spanish_only = get_crisis_resources("Spanish")
        assert len(spanish_only) >= 1
        assert all(resource.language == "Spanish" for resource in spanish_only)
|
||||
|
||||
|
||||
class TestFormatting:
    """Rendered crisis messages include every key contact number."""

    def test_format_includes_988(self):
        assert "988" in format_crisis_resources()

    def test_format_includes_741741(self):
        assert "741741" in format_crisis_resources()

    def test_format_includes_911(self):
        assert "911" in format_crisis_resources()

    def test_immediate_help_includes_911_first(self):
        message = get_immediate_help_message()
        assert message.startswith("If you are in immediate danger")

    def test_format_not_empty(self):
        assert len(format_crisis_resources()) > 100
|
||||
274
tests/test_poka_yoke.py
Normal file
274
tests/test_poka_yoke.py
Normal file
@@ -0,0 +1,274 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
test_poka_yoke.py — Tests for the tool call validation firewall.
|
||||
|
||||
Covers: unknown tool, bad param type, missing required arg,
|
||||
extra unknown param, enum validation, closest-name suggestion.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||
|
||||
from tools.poka_yoke import (
|
||||
validate_tool_call,
|
||||
_find_closest_name,
|
||||
_validate_type,
|
||||
_truncate,
|
||||
)
|
||||
|
||||
|
||||
# ── Mock Registry ─────────────────────────────────────────────────────────────
|
||||
|
||||
class MockEntry:
    """Minimal stand-in for a registry entry: a name plus its tool schema.

    The toolset is fixed to "test" — the validator under test only reads it,
    so any constant value works.
    """
    def __init__(self, name, schema):
        self.name = name
        self.schema = schema
        self.toolset = "test"
|
||||
|
||||
|
||||
# Tool schemas used by the mocked registry.  Each value mirrors an
# OpenAI-style function schema: a top-level name plus a JSON-schema
# "parameters" object with "properties" and "required".
MOCK_TOOLS = {
    "read_file": MockEntry("read_file", {
        "name": "read_file",
        "description": "Read a file",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path"},
                "offset": {"type": "integer", "description": "Start line"},
                "limit": {"type": "integer", "description": "Max lines"},
            },
            "required": ["path"],
        },
    }),
    # Single required string param plus an optional integer.
    "web_search": MockEntry("web_search", {
        "name": "web_search",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "max_results": {"type": "integer"},
            },
            "required": ["query"],
        },
    }),
    # Two required params — exercises multiple-missing-arg reporting.
    "write_file": MockEntry("write_file", {
        "name": "write_file",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string"},
                "content": {"type": "string"},
            },
            "required": ["path", "content"],
        },
    }),
    # Includes a boolean param — exercises boolean coercion.
    "terminal": MockEntry("terminal", {
        "name": "terminal",
        "parameters": {
            "type": "object",
            "properties": {
                "command": {"type": "string"},
                "timeout": {"type": "integer"},
                "background": {"type": "boolean"},
            },
            "required": ["command"],
        },
    }),
}
|
||||
|
||||
|
||||
def _mock_registry():
    """Create a MagicMock registry backed by MOCK_TOOLS."""
    registry_stub = MagicMock()
    registry_stub.get_entry = MOCK_TOOLS.get
    registry_stub.get_all_tool_names = lambda: list(MOCK_TOOLS.keys())
    return registry_stub
|
||||
|
||||
|
||||
# ── Test: Unknown Tool ────────────────────────────────────────────────────────
|
||||
|
||||
class TestUnknownTool:
    """Calls naming a tool that is not registered."""

    @staticmethod
    def _as_empty_registry(mock_reg):
        """Configure the patched registry to know no tools by name."""
        mock_reg.get_entry.return_value = None
        mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS.keys())

    @patch("tools.poka_yoke.registry")
    def test_unknown_tool_rejected(self, mock_reg):
        self._as_empty_registry(mock_reg)
        is_valid, _name, _params, msgs = validate_tool_call("nonexistent_tool", {})
        assert is_valid is False
        assert len(msgs) > 0
        assert "nonexistent_tool" in msgs[0]
        assert "Unknown tool" in msgs[0]

    @patch("tools.poka_yoke.registry")
    def test_unknown_tool_lists_available(self, mock_reg):
        self._as_empty_registry(mock_reg)
        is_valid, _name, _params, msgs = validate_tool_call("foo", {})
        assert is_valid is False
        assert "read_file" in msgs[0]

    @patch("tools.poka_yoke.registry")
    def test_close_name_suggests_correction(self, mock_reg):
        self._as_empty_registry(mock_reg)
        _is_valid, name, _params, msgs = validate_tool_call("readfile", {})
        assert "read_file" in msgs[0]
        assert name == "read_file"
|
||||
|
||||
|
||||
# ── Test: Missing Required Args ───────────────────────────────────────────────
|
||||
|
||||
class TestMissingRequired:
    """Required arguments must be present for a call to validate."""

    @patch("tools.poka_yoke.registry")
    def test_missing_required_rejected(self, mock_reg):
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
        is_valid, _name, _params, msgs = validate_tool_call("read_file", {})
        assert is_valid is False
        assert any("Missing required" in m for m in msgs)
        assert any("'path'" in m for m in msgs)

    @patch("tools.poka_yoke.registry")
    def test_multiple_missing_required(self, mock_reg):
        mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]
        is_valid, _name, _params, msgs = validate_tool_call("write_file", {})
        assert is_valid is False
        # Both missing arguments are reported, not just the first.
        assert any("'path'" in m for m in msgs)
        assert any("'content'" in m for m in msgs)

    @patch("tools.poka_yoke.registry")
    def test_required_present_passes(self, mock_reg):
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
        is_valid, _name, _params, _msgs = validate_tool_call(
            "read_file", {"path": "test.txt"}
        )
        assert is_valid is True
|
||||
|
||||
|
||||
# ── Test: Type Validation ─────────────────────────────────────────────────────
|
||||
|
||||
class TestTypeValidation:
    """Parameter type checking and safe type coercion."""

    @patch("tools.poka_yoke.registry")
    def test_wrong_type_rejected(self, mock_reg):
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
        is_valid, _name, _params, msgs = validate_tool_call(
            "read_file", {"path": "test.txt", "offset": "not_a_number"}
        )
        assert is_valid is False
        assert any("offset" in m and "integer" in m for m in msgs)

    @patch("tools.poka_yoke.registry")
    def test_string_to_int_coercion(self, mock_reg):
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
        is_valid, _name, params, _msgs = validate_tool_call(
            "read_file", {"path": "test.txt", "offset": "42"}
        )
        # A numeric string is coerced rather than rejected.
        assert is_valid is True
        assert params is not None
        assert params["offset"] == 42

    @patch("tools.poka_yoke.registry")
    def test_boolean_coercion(self, mock_reg):
        mock_reg.get_entry.return_value = MOCK_TOOLS["terminal"]
        is_valid, _name, params, _msgs = validate_tool_call(
            "terminal", {"command": "ls", "background": "true"}
        )
        assert is_valid is True
        assert params is not None
        assert params["background"] is True
|
||||
|
||||
|
||||
# ── Test: Unknown Parameters ──────────────────────────────────────────────────
|
||||
|
||||
class TestUnknownParams:
    """Parameters not in the schema are stripped (with a message), not fatal."""

    @patch("tools.poka_yoke.registry")
    def test_unknown_param_removed(self, mock_reg):
        mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
        is_valid, _name, params, msgs = validate_tool_call(
            "read_file", {"path": "test.txt", "bogus_param": "value"}
        )
        assert is_valid is True
        assert params is not None
        assert "bogus_param" not in params
        assert "path" in params
        assert any("Unknown parameter" in m for m in msgs)
|
||||
|
||||
|
||||
# ── Test: Valid Calls Pass Through ────────────────────────────────────────────
|
||||
|
||||
class TestValidCalls:
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_valid_read_file(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"read_file", {"path": "test.txt", "offset": 1, "limit": 100}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
assert name is None
|
||||
assert params is None
|
||||
assert msgs == []
|
||||
|
||||
@patch("tools.poka_yoke.registry")
|
||||
def test_valid_write_file(self, mock_reg):
|
||||
mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]
|
||||
|
||||
is_valid, name, params, msgs = validate_tool_call(
|
||||
"write_file", {"path": "out.txt", "content": "hello"}
|
||||
)
|
||||
|
||||
assert is_valid is True
|
||||
|
||||
|
||||
# ── Test: Helper Functions ────────────────────────────────────────────────────
|
||||
|
||||
class TestHelpers:
|
||||
def test_find_closest_exact_prefix(self):
|
||||
assert _find_closest_name("readfil", ["read_file", "write_file"]) == "read_file"
|
||||
|
||||
def test_find_closest_substring(self):
|
||||
assert _find_closest_name("file", ["read_file", "web_search"]) == "read_file"
|
||||
|
||||
def test_find_closest_no_match(self):
|
||||
assert _find_closest_name("xyzzy", ["read_file", "write_file"]) is None
|
||||
|
||||
def test_validate_type_string(self):
|
||||
ok, val = _validate_type("x", "hello", "string")
|
||||
assert ok is True
|
||||
|
||||
def test_validate_type_int_coercion(self):
|
||||
ok, val = _validate_type("x", "42", "integer")
|
||||
assert ok is True
|
||||
assert val == 42
|
||||
|
||||
def test_validate_type_int_bad(self):
|
||||
ok, val = _validate_type("x", "not_int", "integer")
|
||||
assert ok is False
|
||||
|
||||
def test_truncate(self):
|
||||
assert _truncate("hello", 10) == "hello"
|
||||
assert _truncate("hello world", 8) == "hello..."
|
||||
76
tests/test_profile_isolation.py
Normal file
76
tests/test_profile_isolation.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for profile session isolation (#891)."""
|
||||
|
||||
import sys
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
# Override paths for testing
|
||||
import agent.profile_isolation as iso_mod
|
||||
_test_dir = Path(tempfile.mkdtemp())
|
||||
iso_mod.PROFILE_TAGS_FILE = _test_dir / "tags.json"
|
||||
|
||||
|
||||
def test_tag_session():
|
||||
"""Session gets tagged with profile."""
|
||||
profile = iso_mod.tag_session("sess-1", "sprint")
|
||||
assert profile == "sprint"
|
||||
assert iso_mod.get_session_profile("sess-1") == "sprint"
|
||||
|
||||
|
||||
def test_default_profile():
|
||||
"""Sessions tagged with default when no profile specified."""
|
||||
profile = iso_mod.tag_session("sess-2")
|
||||
assert profile is not None
|
||||
|
||||
|
||||
def test_get_session_profile():
|
||||
"""Can retrieve profile for tagged session."""
|
||||
iso_mod.tag_session("sess-3", "fenrir")
|
||||
assert iso_mod.get_session_profile("sess-3") == "fenrir"
|
||||
|
||||
|
||||
def test_untagged_returns_none():
|
||||
"""Untagged session returns None."""
|
||||
assert iso_mod.get_session_profile("nonexistent") is None
|
||||
|
||||
|
||||
def test_profile_stats():
|
||||
"""Stats reflect tagged sessions."""
|
||||
iso_mod.tag_session("s1", "default")
|
||||
iso_mod.tag_session("s2", "sprint")
|
||||
iso_mod.tag_session("s3", "sprint")
|
||||
stats = iso_mod.get_profile_stats()
|
||||
assert stats["total_tagged_sessions"] >= 3
|
||||
assert "sprint" in stats["profile_counts"]
|
||||
|
||||
|
||||
def test_filter_sessions():
|
||||
"""Filter returns only matching profile sessions."""
|
||||
iso_mod.tag_session("filter-1", "alpha")
|
||||
iso_mod.tag_session("filter-2", "beta")
|
||||
iso_mod.tag_session("filter-3", "alpha")
|
||||
|
||||
sessions = [
|
||||
{"session_id": "filter-1"},
|
||||
{"session_id": "filter-2"},
|
||||
{"session_id": "filter-3"},
|
||||
]
|
||||
|
||||
filtered = iso_mod.filter_sessions_by_profile(sessions, "alpha")
|
||||
ids = [s["session_id"] for s in filtered]
|
||||
assert "filter-1" in ids
|
||||
assert "filter-3" in ids
|
||||
assert "filter-2" not in ids
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tests = [test_tag_session, test_default_profile, test_get_session_profile,
|
||||
test_untagged_returns_none, test_profile_stats, test_filter_sessions]
|
||||
for t in tests:
|
||||
print(f"Running {t.__name__}...")
|
||||
t()
|
||||
print(" PASS")
|
||||
print("\nAll tests passed.")
|
||||
302
tests/test_skill_manager_autorevert.py
Normal file
302
tests/test_skill_manager_autorevert.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""
|
||||
Integration tests for poka-yoke auto-revert on incomplete skill edits (#923).
|
||||
|
||||
Verifies the transactional write-validate-commit-or-rollback pattern:
|
||||
- Backup created before every write
|
||||
- Post-write validation triggers revert on corrupted/empty file
|
||||
- Successful writes clean up the backup
|
||||
- At most MAX_BACKUPS_PER_FILE backups retained per file
|
||||
"""
|
||||
|
||||
import time
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from tools.skill_manager_tool import (
|
||||
MAX_BACKUPS_PER_FILE,
|
||||
_backup_skill_file,
|
||||
_cleanup_old_backups,
|
||||
_edit_skill,
|
||||
_patch_skill,
|
||||
_revert_from_backup,
|
||||
_validate_written_file,
|
||||
_write_file,
|
||||
)
|
||||
|
||||
|
||||
VALID_SKILL_MD = """\
|
||||
---
|
||||
name: test-skill
|
||||
description: A skill for testing auto-revert
|
||||
---
|
||||
|
||||
## Overview
|
||||
Test skill body content.
|
||||
"""
|
||||
|
||||
VALID_UPDATED_MD = """\
|
||||
---
|
||||
name: test-skill
|
||||
description: Updated description
|
||||
---
|
||||
|
||||
## Overview
|
||||
Updated test skill body.
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_skill(tmp_path: Path, content: str = VALID_SKILL_MD) -> Path:
|
||||
"""Write a minimal SKILL.md in *tmp_path* and return its path."""
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text(content, encoding="utf-8")
|
||||
return skill_md
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _backup_skill_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBackupSkillFile:
|
||||
def test_creates_bak_file(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
backup = _backup_skill_file(skill_md)
|
||||
assert backup is not None
|
||||
assert backup.exists()
|
||||
assert ".bak." in backup.name
|
||||
|
||||
def test_backup_preserves_content(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
backup = _backup_skill_file(skill_md)
|
||||
assert backup.read_text(encoding="utf-8") == VALID_SKILL_MD
|
||||
|
||||
def test_no_backup_for_nonexistent_file(self, tmp_path):
|
||||
missing = tmp_path / "SKILL.md"
|
||||
assert _backup_skill_file(missing) is None
|
||||
|
||||
def test_backup_name_contains_timestamp(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
before = int(time.time())
|
||||
backup = _backup_skill_file(skill_md)
|
||||
after = int(time.time())
|
||||
ts = int(backup.name.split(".bak.")[-1])
|
||||
assert before <= ts <= after
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _cleanup_old_backups
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestCleanupOldBackups:
|
||||
def _create_backups(self, skill_md: Path, n: int) -> list:
|
||||
backups = []
|
||||
for i in range(n):
|
||||
bp = skill_md.parent / f"{skill_md.name}.bak.{1000 + i}"
|
||||
bp.write_text("backup content", encoding="utf-8")
|
||||
backups.append(bp)
|
||||
return backups
|
||||
|
||||
def test_prunes_excess_backups(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
self._create_backups(skill_md, MAX_BACKUPS_PER_FILE + 2)
|
||||
_cleanup_old_backups(skill_md)
|
||||
remaining = list(tmp_path.glob(f"SKILL.md.bak.*"))
|
||||
assert len(remaining) == MAX_BACKUPS_PER_FILE
|
||||
|
||||
def test_keeps_backups_within_limit(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
self._create_backups(skill_md, MAX_BACKUPS_PER_FILE)
|
||||
_cleanup_old_backups(skill_md)
|
||||
remaining = list(tmp_path.glob("SKILL.md.bak.*"))
|
||||
assert len(remaining) == MAX_BACKUPS_PER_FILE
|
||||
|
||||
def test_noop_when_no_backups(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
_cleanup_old_backups(skill_md) # should not raise
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _validate_written_file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestValidateWrittenFile:
|
||||
def test_valid_skill_md(self, tmp_path):
|
||||
skill_md = _make_skill(tmp_path)
|
||||
assert _validate_written_file(skill_md, is_skill_md=True) is None
|
||||
|
||||
def test_empty_file_fails(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("", encoding="utf-8")
|
||||
err = _validate_written_file(skill_md, is_skill_md=False)
|
||||
assert err is not None
|
||||
assert "empty" in err.lower()
|
||||
|
||||
def test_broken_frontmatter_fails(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("Not a skill\nno frontmatter\n", encoding="utf-8")
|
||||
err = _validate_written_file(skill_md, is_skill_md=True)
|
||||
assert err is not None
|
||||
|
||||
def test_missing_required_field_fails(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("---\ndescription: no name\n---\nbody\n", encoding="utf-8")
|
||||
err = _validate_written_file(skill_md, is_skill_md=True)
|
||||
assert err is not None
|
||||
assert "name" in err.lower()
|
||||
|
||||
def test_missing_file_returns_error(self, tmp_path):
|
||||
missing = tmp_path / "SKILL.md"
|
||||
err = _validate_written_file(missing, is_skill_md=False)
|
||||
assert err is not None
|
||||
|
||||
def test_non_skill_md_only_checks_emptiness(self, tmp_path):
|
||||
ref = tmp_path / "references" / "guide.md"
|
||||
ref.parent.mkdir()
|
||||
ref.write_text("# Guide\nsome content\n", encoding="utf-8")
|
||||
assert _validate_written_file(ref, is_skill_md=False) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Unit tests: _revert_from_backup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRevertFromBackup:
|
||||
def test_restores_from_backup(self, tmp_path):
|
||||
original = "original content"
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text(original, encoding="utf-8")
|
||||
backup = tmp_path / "SKILL.md.bak.99999"
|
||||
backup.write_text(original, encoding="utf-8")
|
||||
|
||||
skill_md.write_text("corrupted content", encoding="utf-8")
|
||||
_revert_from_backup(skill_md, backup)
|
||||
assert skill_md.read_text(encoding="utf-8") == original
|
||||
|
||||
def test_removes_file_when_no_backup(self, tmp_path):
|
||||
skill_md = tmp_path / "SKILL.md"
|
||||
skill_md.write_text("corrupted", encoding="utf-8")
|
||||
_revert_from_backup(skill_md, None)
|
||||
assert not skill_md.exists()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: _edit_skill auto-revert
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEditSkillAutoRevert:
|
||||
@pytest.fixture
|
||||
def skill_dir(self, tmp_path):
|
||||
"""Create a minimal skill directory and patch _find_skill."""
|
||||
d = tmp_path / "test-skill"
|
||||
d.mkdir()
|
||||
skill_md = d / "SKILL.md"
|
||||
skill_md.write_text(VALID_SKILL_MD, encoding="utf-8")
|
||||
return d
|
||||
|
||||
def test_successful_edit_removes_backup(self, skill_dir):
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
assert result["success"] is True
|
||||
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||
assert len(backups) == 0
|
||||
|
||||
def test_revert_when_post_write_validation_fails(self, skill_dir):
|
||||
"""Simulate a write that produces an empty file on disk."""
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
|
||||
def corrupt_write(path, content, **kw):
|
||||
# Write an empty file to simulate truncation
|
||||
path.write_text("", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "reverted" in result["error"].lower()
|
||||
# Original content restored
|
||||
assert skill_md.read_text(encoding="utf-8") == VALID_SKILL_MD
|
||||
|
||||
def test_backup_preserved_after_revert(self, skill_dir):
|
||||
"""A .bak file should survive when the edit is reverted (debugging aid)."""
|
||||
def corrupt_write(path, content, **kw):
|
||||
path.write_text("", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
_edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||
assert len(backups) == 1
|
||||
|
||||
def test_max_backups_enforced_after_multiple_edits(self, skill_dir):
|
||||
"""After many successful edits, at most MAX_BACKUPS_PER_FILE .bak files remain."""
|
||||
n = MAX_BACKUPS_PER_FILE + 4
|
||||
for i in range(n):
|
||||
# Plant stale backup files to simulate prior runs
|
||||
bp = skill_dir / f"SKILL.md.bak.{1000 + i}"
|
||||
bp.write_text("old backup", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||
|
||||
assert result["success"] is True
|
||||
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||
assert len(backups) <= MAX_BACKUPS_PER_FILE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests: _patch_skill auto-revert
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestPatchSkillAutoRevert:
|
||||
@pytest.fixture
|
||||
def skill_dir(self, tmp_path):
|
||||
d = tmp_path / "test-skill"
|
||||
d.mkdir()
|
||||
(d / "SKILL.md").write_text(VALID_SKILL_MD, encoding="utf-8")
|
||||
return d
|
||||
|
||||
def test_successful_patch_removes_backup(self, skill_dir):
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _patch_skill(
|
||||
"test-skill",
|
||||
"A skill for testing auto-revert",
|
||||
"Updated description",
|
||||
)
|
||||
|
||||
assert result["success"] is True
|
||||
assert len(list(skill_dir.glob("SKILL.md.bak.*"))) == 0
|
||||
|
||||
def test_revert_on_corrupt_write(self, skill_dir):
|
||||
skill_md = skill_dir / "SKILL.md"
|
||||
original = skill_md.read_text(encoding="utf-8")
|
||||
|
||||
def corrupt_write(path, content, **kw):
|
||||
path.write_text("", encoding="utf-8")
|
||||
|
||||
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||
mock_find.return_value = {"path": skill_dir}
|
||||
result = _patch_skill(
|
||||
"test-skill",
|
||||
"A skill for testing",
|
||||
"A skill for testing auto-revert",
|
||||
)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "reverted" in result["error"].lower()
|
||||
assert skill_md.read_text(encoding="utf-8") == original
|
||||
82
tests/test_syntax_validation.py
Normal file
82
tests/test_syntax_validation.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Tests for Python syntax validation in execute_code."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Import the validation function directly
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||
from tools.code_execution_tool import _validate_python_syntax
|
||||
|
||||
|
||||
class TestValidatePythonSyntax:
|
||||
"""Test _validate_python_syntax catches errors before subprocess spawn."""
|
||||
|
||||
def test_valid_code_returns_none(self):
|
||||
assert _validate_python_syntax("print('hello')") is None
|
||||
|
||||
def test_valid_multiline_returns_none(self):
|
||||
code = """
|
||||
import os
|
||||
def foo():
|
||||
return 42
|
||||
result = foo()
|
||||
"""
|
||||
assert _validate_python_syntax(code) is None
|
||||
|
||||
def test_syntax_error_detected(self):
|
||||
result = _validate_python_syntax("def foo(
|
||||
")
|
||||
assert result is not None
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
assert "line" in data
|
||||
assert "message" in data
|
||||
|
||||
def test_missing_colon(self):
|
||||
result = _validate_python_syntax("def foo()
|
||||
pass")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
assert data["line"] == 1
|
||||
|
||||
def test_unmatched_paren(self):
|
||||
result = _validate_python_syntax("print('hello'")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
|
||||
def test_indentation_error(self):
|
||||
result = _validate_python_syntax("def foo():
|
||||
pass")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
assert data["line"] == 2
|
||||
|
||||
def test_invalid_character(self):
|
||||
result = _validate_python_syntax("x = 5 √ 2")
|
||||
data = json.loads(result)
|
||||
assert data["syntax_error"] is True
|
||||
|
||||
def test_error_format_has_required_fields(self):
|
||||
result = _validate_python_syntax("def(
|
||||
")
|
||||
data = json.loads(result)
|
||||
assert "error" in data
|
||||
assert "syntax_error" in data
|
||||
assert "line" in data
|
||||
assert "offset" in data
|
||||
assert "message" in data
|
||||
|
||||
def test_empty_string_returns_none(self):
|
||||
# Empty code is caught by the guard before validation
|
||||
# But if called directly, ast.parse("") is valid
|
||||
assert _validate_python_syntax("") is None
|
||||
|
||||
def test_comment_only_returns_none(self):
|
||||
assert _validate_python_syntax("# just a comment") is None
|
||||
|
||||
def test_complex_valid_code(self):
|
||||
code =
|
||||
58
tests/test_time_aware_routing.py
Normal file
58
tests/test_time_aware_routing.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Tests for time-aware model routing."""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from agent.time_aware_routing import (
|
||||
resolve_time_aware_model,
|
||||
get_hour_error_rate,
|
||||
is_off_hours,
|
||||
get_routing_report,
|
||||
)
|
||||
|
||||
|
||||
class TestErrorRates:
|
||||
def test_evening_high_error(self):
|
||||
assert get_hour_error_rate(18) == 9.4
|
||||
assert get_hour_error_rate(19) == 8.1
|
||||
|
||||
def test_morning_low_error(self):
|
||||
assert get_hour_error_rate(9) == 4.0
|
||||
assert get_hour_error_rate(12) == 4.0
|
||||
|
||||
def test_default_for_unknown(self):
|
||||
assert get_hour_error_rate(15) == 4.0
|
||||
|
||||
|
||||
class TestOffHours:
|
||||
def test_evening_is_off_hours(self):
|
||||
assert is_off_hours(20) is True
|
||||
assert is_off_hours(2) is True
|
||||
|
||||
def test_business_hours_not_off(self):
|
||||
assert is_off_hours(9) is False
|
||||
assert is_off_hours(14) is False
|
||||
|
||||
|
||||
class TestRouting:
|
||||
def test_interactive_uses_base_model(self):
|
||||
d = resolve_time_aware_model("my-model", "my-provider", is_cron=False, hour=18)
|
||||
assert d.model == "my-model"
|
||||
assert "Interactive" in d.reason
|
||||
|
||||
def test_cron_low_error_uses_base(self):
|
||||
d = resolve_time_aware_model("cheap-model", is_cron=True, hour=10)
|
||||
assert d.model == "cheap-model"
|
||||
|
||||
def test_cron_high_error_upgrades(self):
|
||||
d = resolve_time_aware_model("cheap-model", is_cron=True, hour=18)
|
||||
assert d.model != "cheap-model"
|
||||
assert d.is_off_hours is True
|
||||
|
||||
def test_routing_report(self):
|
||||
report = get_routing_report()
|
||||
assert "Time-Aware Model Routing" in report
|
||||
assert "18:00" in report
|
||||
237
tests/test_token_budget.py
Normal file
237
tests/test_token_budget.py
Normal file
@@ -0,0 +1,237 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tests for agent/token_budget.py — Poka-yoke context overflow guard.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from agent.token_budget import (
|
||||
TokenBudget,
|
||||
BudgetLevel,
|
||||
BudgetStatus,
|
||||
WARN_PERCENT,
|
||||
CAUTION_PERCENT,
|
||||
CRITICAL_PERCENT,
|
||||
STOP_PERCENT,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def budget():
|
||||
"""Standard 128K context budget."""
|
||||
return TokenBudget(context_length=128_000)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def small_budget():
|
||||
"""4K context for tight testing."""
|
||||
return TokenBudget(context_length=4_000)
|
||||
|
||||
|
||||
# ── Threshold Levels ──────────────────────────────────────────────────
|
||||
|
||||
class TestThresholds:
|
||||
def test_normal_below_60(self, budget):
|
||||
budget.update(50_000) # 39%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.NORMAL
|
||||
assert not status.should_compress
|
||||
assert not status.should_block_tools
|
||||
assert not status.should_terminate
|
||||
|
||||
def test_warning_at_60(self, budget):
|
||||
budget.update(int(128_000 * 0.62)) # 62%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.WARNING
|
||||
assert not status.should_compress
|
||||
assert not status.should_block_tools
|
||||
|
||||
def test_caution_at_80(self, budget):
|
||||
budget.update(int(128_000 * 0.82)) # 82%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.CAUTION
|
||||
assert status.should_compress
|
||||
assert not status.should_block_tools
|
||||
assert not status.should_terminate
|
||||
|
||||
def test_critical_at_90(self, budget):
|
||||
budget.update(int(128_000 * 0.91)) # 91%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.CRITICAL
|
||||
assert status.should_compress
|
||||
assert status.should_block_tools
|
||||
assert not status.should_terminate
|
||||
|
||||
def test_stop_at_95(self, budget):
|
||||
budget.update(int(128_000 * 0.96)) # 96%
|
||||
status = budget.check()
|
||||
assert status.level == BudgetLevel.STOP
|
||||
assert status.should_compress
|
||||
assert status.should_block_tools
|
||||
assert status.should_terminate
|
||||
|
||||
def test_small_context_thresholds(self, small_budget):
|
||||
# 4K * 0.60 = 2400
|
||||
small_budget.update(2450)
|
||||
assert small_budget.check().level == BudgetLevel.WARNING
|
||||
|
||||
small_budget.update(3250) # 4K * 0.81
|
||||
assert small_budget.check().level == BudgetLevel.CAUTION
|
||||
|
||||
small_budget.update(3650) # 4K * 0.91
|
||||
assert small_budget.check().level == BudgetLevel.CRITICAL
|
||||
|
||||
small_budget.update(3850) # 4K * 0.96
|
||||
assert small_budget.check().level == BudgetLevel.STOP
|
||||
|
||||
|
||||
# ── Convenience Methods ───────────────────────────────────────────────
|
||||
|
||||
class TestConvenienceMethods:
|
||||
def test_should_compress(self, budget):
|
||||
budget.update(int(128_000 * 0.79))
|
||||
assert not budget.should_compress()
|
||||
budget.update(int(128_000 * 0.80))
|
||||
assert budget.should_compress()
|
||||
|
||||
def test_should_block_tools(self, budget):
|
||||
budget.update(int(128_000 * 0.89))
|
||||
assert not budget.should_block_tools()
|
||||
budget.update(int(128_000 * 0.90))
|
||||
assert budget.should_block_tools()
|
||||
|
||||
def test_should_terminate(self, budget):
|
||||
budget.update(int(128_000 * 0.94))
|
||||
assert not budget.should_terminate()
|
||||
budget.update(int(128_000 * 0.95))
|
||||
assert budget.should_terminate()
|
||||
|
||||
|
||||
# ── Tool Output Budgeting ─────────────────────────────────────────────
|
||||
|
||||
class TestToolOutputBudget:
|
||||
def test_normal_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.50))
|
||||
assert budget.tool_output_budget() == 50_000
|
||||
|
||||
def test_warning_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.65))
|
||||
assert budget.tool_output_budget() == 20_000
|
||||
|
||||
def test_caution_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.85))
|
||||
assert budget.tool_output_budget() == 8_000
|
||||
|
||||
def test_critical_budget(self, budget):
|
||||
budget.update(int(128_000 * 0.92))
|
||||
assert budget.tool_output_budget() == 2_000
|
||||
|
||||
def test_truncate_short_unchanged(self, budget):
|
||||
result = budget.truncate_tool_output("short text", max_chars=1000)
|
||||
assert result == "short text"
|
||||
|
||||
def test_truncate_long(self, budget):
|
||||
long_text = "A" * 100_000
|
||||
result = budget.truncate_tool_output(long_text, max_chars=5_000)
|
||||
assert len(result) <= 5_100 # small overhead for notice
|
||||
assert "truncated" in result
|
||||
assert "A" in result[:2500] # head preserved
|
||||
assert "A" in result[-2500:] # tail preserved
|
||||
|
||||
def test_truncate_very_small(self, budget):
|
||||
long_text = "X" * 1000
|
||||
result = budget.truncate_tool_output(long_text, max_chars=50)
|
||||
assert len(result) <= 50 + 20
|
||||
assert "truncated" in result
|
||||
|
||||
|
||||
# ── Growth Tracking ───────────────────────────────────────────────────
|
||||
|
||||
class TestGrowthTracking:
|
||||
def test_growth_rate(self, budget):
|
||||
budget.update(10_000)
|
||||
budget.update(15_000)
|
||||
budget.update(20_000)
|
||||
assert budget.growth_rate() == 5_000.0
|
||||
|
||||
def test_turns_remaining(self, budget):
|
||||
budget.update(10_000)
|
||||
budget.update(15_000)
|
||||
budget.update(20_000)
|
||||
# rate=5000, remaining=108000, turns=~21
|
||||
turns = budget.turns_remaining()
|
||||
assert turns is not None
|
||||
assert 18 <= turns <= 24
|
||||
|
||||
def test_no_history(self, budget):
|
||||
assert budget.growth_rate() is None
|
||||
assert budget.turns_remaining() is None
|
||||
|
||||
|
||||
# ── Status Indicators ─────────────────────────────────────────────────
|
||||
|
||||
class TestStatusIndicators:
|
||||
def test_indicator_normal(self, budget):
|
||||
budget.update(int(128_000 * 0.50))
|
||||
status = budget.check()
|
||||
indicator = status.to_indicator()
|
||||
assert "50" in indicator
|
||||
|
||||
def test_indicator_warning(self, budget):
|
||||
budget.update(int(128_000 * 0.65))
|
||||
status = budget.check()
|
||||
indicator = status.to_indicator()
|
||||
assert "\u26a0" in indicator or "65" in indicator
|
||||
|
||||
def test_bar(self, budget):
|
||||
budget.update(int(128_000 * 0.50))
|
||||
status = budget.check()
|
||||
bar = status.to_bar()
|
||||
assert "50" in bar
|
||||
|
||||
def test_summary(self, budget):
|
||||
budget.update(50_000)
|
||||
summary = budget.summary()
|
||||
assert "50,000" in summary
|
||||
assert "128,000" in summary
|
||||
assert "NORMAL" in summary
|
||||
|
||||
|
||||
# ── Reset ─────────────────────────────────────────────────────────────
|
||||
|
||||
class TestReset:
|
||||
def test_reset_clears_state(self, budget):
|
||||
budget.update(int(128_000 * 0.90))
|
||||
budget.reset()
|
||||
assert budget.tokens_used == 0
|
||||
assert budget.check().level == BudgetLevel.NORMAL
|
||||
assert budget.growth_rate() is None
|
||||
|
||||
|
||||
# ── Edge Cases ────────────────────────────────────────────────────────
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_exact_threshold_boundary(self, budget):
|
||||
# Exactly at 60%
|
||||
budget.update(int(128_000 * 0.60))
|
||||
assert budget.check().level == BudgetLevel.WARNING
|
||||
|
||||
def test_zero_context(self):
|
||||
budget = TokenBudget(context_length=0)
|
||||
status = budget.check()
|
||||
assert status.percent_used == 0
|
||||
|
||||
def test_remaining_for_response(self, budget):
|
||||
budget.update(100_000)
|
||||
remaining = budget.remaining_for_response()
|
||||
# 128000 - 100000 - 6400 (5% reserve) = 21600
|
||||
assert remaining > 0
|
||||
assert remaining < 128_000
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
76
tests/test_tool_fixation_detector.py
Normal file
76
tests/test_tool_fixation_detector.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""Tests for tool fixation detection."""
|
||||
|
||||
import pytest
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from agent.tool_fixation_detector import ToolFixationDetector, get_fixation_detector
|
||||
|
||||
|
||||
class TestFixationDetection:
    """Behavioral tests for ToolFixationDetector streak tracking and nudging."""

    def test_no_fixation_below_threshold(self):
        # Four identical calls stay under a threshold of five: never nudges.
        detector = ToolFixationDetector(threshold=5)
        results = [detector.record("execute_code") for _ in range(4)]
        assert all(r is None for r in results)

    def test_fixation_at_threshold(self):
        detector = ToolFixationDetector(threshold=3)
        nudge = None
        for _ in range(3):
            nudge = detector.record("execute_code")
        assert nudge is not None
        assert "execute_code" in nudge
        assert "3 times" in nudge

    def test_fixation_above_threshold(self):
        detector = ToolFixationDetector(threshold=3)
        for _ in range(3):
            detector.record("execute_code")  # third call hits the threshold
        # A fourth call keeps nudging past the threshold.
        assert detector.record("execute_code") is not None

    def test_streak_resets_on_different_tool(self):
        detector = ToolFixationDetector(threshold=3)
        detector.record("execute_code")
        detector.record("execute_code")
        detector.record("terminal")  # different tool breaks the streak
        assert detector._streak_count == 1
        assert detector._current_streak == "terminal"

    def test_nudges_sent_counter(self):
        detector = ToolFixationDetector(threshold=2)
        for _ in range(3):
            detector.record("a")  # nudges fire on calls 2 and 3
        assert detector.nudges_sent == 2

    def test_events_recorded(self):
        detector = ToolFixationDetector(threshold=2)
        detector.record("x")
        detector.record("x")
        assert len(detector.events) == 1
        event = detector.events[0]
        assert event.tool_name == "x"
        assert event.streak_length == 2

    def test_report(self):
        detector = ToolFixationDetector(threshold=2)
        detector.record("x")
        detector.record("x")
        assert "x" in detector.format_report()

    def test_reset(self):
        detector = ToolFixationDetector(threshold=2)
        detector.record("x")
        detector.record("x")
        detector.reset()
        assert detector._streak_count == 0
        assert detector._current_streak == ""

    def test_singleton(self):
        # get_fixation_detector() returns one shared module-level instance.
        assert get_fixation_detector() is get_fixation_detector()
|
||||
67
tests/test_tool_validator.py
Normal file
67
tests/test_tool_validator.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
Tests for tool hallucination detection (#922).
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from tools.tool_validator import ToolHallucinationDetector, ValidationSeverity
|
||||
|
||||
|
||||
class TestToolHallucinationDetector:
    """Validation-path tests for ToolHallucinationDetector (issue #922)."""

    def setup_method(self):
        # Fresh detector per test with a single registered tool schema.
        self.detector = ToolHallucinationDetector()
        read_file_schema = {
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string"},
                    "encoding": {"type": "string"},
                },
                "required": ["path"],
            }
        }
        self.detector.register_tool("read_file", read_file_schema)

    def test_valid_tool_call(self):
        outcome = self.detector.validate_tool_call("read_file", {"path": "/tmp/file.txt"})
        assert outcome.valid is True
        assert not outcome.blocking_issues

    def test_unknown_tool(self):
        outcome = self.detector.validate_tool_call("hallucinated_tool", {})
        assert outcome.valid is False
        assert "UNKNOWN_TOOL" in {issue.code for issue in outcome.issues}

    def test_missing_required_param(self):
        outcome = self.detector.validate_tool_call("read_file", {})
        assert outcome.valid is False
        assert "MISSING_REQUIRED" in {issue.code for issue in outcome.issues}

    def test_wrong_type(self):
        outcome = self.detector.validate_tool_call("read_file", {"path": 123})
        assert outcome.valid is False
        assert "WRONG_TYPE" in {issue.code for issue in outcome.issues}

    def test_unknown_param_warning(self):
        outcome = self.detector.validate_tool_call(
            "read_file", {"path": "/tmp/file.txt", "unknown": "value"}
        )
        # Unknown params warn but do not block the call.
        assert outcome.valid is True
        assert "UNKNOWN_PARAM" in {issue.code for issue in outcome.issues}

    def test_placeholder_detection(self):
        outcome = self.detector.validate_tool_call("read_file", {"path": "<placeholder>"})
        assert "PLACEHOLDER_VALUE" in {issue.code for issue in outcome.issues}

    def test_rejection_stats(self):
        self.detector.validate_tool_call("unknown_tool", {})
        self.detector.validate_tool_call("read_file", {})
        stats = self.detector.get_rejection_stats()
        assert stats["total"] >= 2

    def test_rejection_response(self):
        from tools.tool_validator import create_rejection_response

        outcome = self.detector.validate_tool_call("unknown_tool", {})
        response = create_rejection_response(outcome)
        assert response["role"] == "tool"
        assert "rejected" in response["content"].lower()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
@@ -28,6 +28,7 @@ Platform: Linux / macOS only (Unix domain sockets for local). Disabled on Window
|
||||
Remote execution additionally requires Python 3 in the terminal backend.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
@@ -883,6 +884,42 @@ def _execute_remote(
|
||||
return json.dumps(result, ensure_ascii=False)
|
||||
|
||||
|
||||
|
||||
def _validate_python_syntax(code: str) -> Optional[str]:
|
||||
"""Validate Python syntax before subprocess spawn.
|
||||
|
||||
Runs ast.parse() in-process (sub-millisecond) to catch syntax errors
|
||||
before wasting time spawning a sandboxed subprocess.
|
||||
|
||||
Returns:
|
||||
JSON error string with line, offset, message if syntax is invalid.
|
||||
None if syntax is valid.
|
||||
"""
|
||||
try:
|
||||
ast.parse(code)
|
||||
return None
|
||||
except SyntaxError as exc:
|
||||
# Build context: show offending line with caret
|
||||
lines = code.split("\n")
|
||||
error_line = lines[exc.lineno - 1] if exc.lineno and exc.lineno <= len(lines) else ""
|
||||
context = ""
|
||||
if error_line:
|
||||
context = f"\n {error_line}"
|
||||
if exc.offset:
|
||||
context += f"\n {' ' * (exc.offset - 1)}^"
|
||||
|
||||
return json.dumps({
|
||||
"error": f"Python syntax error on line {exc.lineno}: {exc.msg}{context}",
|
||||
"syntax_error": True,
|
||||
"line": exc.lineno,
|
||||
"offset": exc.offset,
|
||||
"message": exc.msg,
|
||||
})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -916,6 +953,11 @@ def execute_code(
|
||||
if not code or not code.strip():
|
||||
return tool_error("No code provided.")
|
||||
|
||||
# Syntax check before subprocess spawn (catches ~15% of errors in <1ms)
|
||||
syntax_error = _validate_python_syntax(code)
|
||||
if syntax_error:
|
||||
return syntax_error
|
||||
|
||||
# Dispatch: remote backends use file-based RPC, local uses UDS
|
||||
from tools.terminal_tool import _get_env_config
|
||||
env_type = _get_env_config()["env_type"]
|
||||
|
||||
183
tools/credential_redact.py
Normal file
183
tools/credential_redact.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Credential Redaction — Block silent credential exposure in tool outputs
|
||||
|
||||
Poka-yoke: Prevent API keys, tokens, passwords from leaking into context.
|
||||
|
||||
Issue: #839
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Root of Hermes' per-user state; audit records of redactions live beneath it.
HERMES_HOME = Path.home() / ".hermes"
AUDIT_DIR = HERMES_HOME / "audit"

# Credential patterns to detect and redact: (regex, replacement) pairs.
# Quote characters inside the character classes are written as \" / '
# so the raw-string literals parse correctly.
CREDENTIAL_PATTERNS = [
    # Vendor API keys
    (r"sk-[a-zA-Z0-9]{20,}", "[REDACTED: OpenAI API key]"),
    (r"sk-ant-[a-zA-Z0-9-]{20,}", "[REDACTED: Anthropic API key]"),
    (r"ghp_[a-zA-Z0-9]{36}", "[REDACTED: GitHub token]"),
    (r"gho_[a-zA-Z0-9]{36}", "[REDACTED: GitHub OAuth token]"),
    (r"glpat-[a-zA-Z0-9-]{20,}", "[REDACTED: GitLab token]"),

    # Bearer tokens
    (r"Bearer\s+[a-zA-Z0-9._-]{20,}", "[REDACTED: Bearer token]"),
    (r"bearer\s+[a-zA-Z0-9._-]{20,}", "[REDACTED: Bearer token]"),

    # Generic tokens/passwords assigned via key[:=] value
    (r"(?:token|TOKEN|Token)[:=]\s*[\"']?[a-zA-Z0-9._-]{20,}[\"']?", "[REDACTED: Token]"),
    (r"(?:password|PASSWORD|Password)[:=]\s*[\"']?[^\s\"']{8,}[\"']?", "[REDACTED: Password]"),
    (r"(?:secret|SECRET|Secret)[:=]\s*[\"']?[a-zA-Z0-9._-]{20,}[\"']?", "[REDACTED: Secret]"),
    (r"(?:api_key|API_KEY|apiKey|ApiKey)[:=]\s*[\"']?[a-zA-Z0-9._-]{20,}[\"']?", "[REDACTED: API key]"),

    # AWS keys
    (r"AKIA[0-9A-Z]{16}", "[REDACTED: AWS access key]"),
    (r"(?:aws_secret_access_key|AWS_SECRET_ACCESS_KEY)[:=]\s*[\"']?[a-zA-Z0-9/+=]{40}[\"']?", "[REDACTED: AWS secret]"),

    # Private keys
    (r"-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----", "[REDACTED: Private key header]"),

    # Connection strings with inline credentials (user:pass@host)
    (r"(?:postgres|mysql|mongodb|redis)://[^:]+:[^@]+@[^\s]+", "[REDACTED: Database connection string]"),
]

# File name patterns that should trigger auto-masking on read.
SENSITIVE_FILE_PATTERNS = [
    r"\.env$",
    r"\.env\.",
    r"\.secret",
    r"credentials",
    r"\.token",
    r"config\.yaml$",
    r"config\.yml$",
    r"config\.json$",
    r"\.netrc$",
    r"\.pgpass$",
]
|
||||
|
||||
|
||||
class CredentialRedactor:
    """Redact credential-looking substrings from text.

    Applies CREDENTIAL_PATTERNS in order (case-insensitively) and, when
    enabled, appends an audit-trail entry under AUDIT_DIR whenever at least
    one redaction occurred.
    """

    def __init__(self, audit_log: bool = True):
        # When True, every redaction event is appended to redactions.jsonl.
        self.audit_log = audit_log
        # Lifetime total of redactions performed by this instance.
        self._redaction_count = 0

    def redact(self, text: str) -> Tuple[str, int]:
        """
        Redact credentials from text.

        Returns:
            Tuple of (redacted_text, number_of_redactions)
        """
        if not text:
            return text, 0

        redacted = text
        count = 0

        for pattern, replacement in CREDENTIAL_PATTERNS:
            # findall first so we know how many hits the sub() replaces.
            matches = re.findall(pattern, redacted, re.IGNORECASE)
            if matches:
                redacted = re.sub(pattern, replacement, redacted, flags=re.IGNORECASE)
                count += len(matches)

        if count > 0:
            self._redaction_count += count
            if self.audit_log:
                # Only the first 100 chars are previewed (and only hashed).
                self._log_redaction(count, text[:100])

        return redacted, count

    def redact_tool_output(self, tool_name: str, output: str) -> Tuple[str, str]:
        """
        Redact tool output and return notice if redactions occurred.

        Returns:
            Tuple of (redacted_output, notice_or_empty)
        """
        redacted, count = self.redact(output)

        if count > 0:
            notice = f"[REDACTED: {count} credential pattern{'s' if count > 1 else ''} found in {tool_name} output]"
            return redacted, notice

        return redacted, ""

    def should_mask_file(self, file_path: str) -> bool:
        """Check if file should have credentials auto-masked."""
        path_lower = file_path.lower()
        return any(re.search(p, path_lower) for p in SENSITIVE_FILE_PATTERNS)

    def mask_file_content(self, content: str, file_path: str) -> str:
        """Mask credentials in file content while preserving structure."""
        if not self.should_mask_file(file_path):
            return content

        lines = content.split("\n")
        masked_lines = []

        for line in lines:
            # Preserve key=value structure but mask values
            if "=" in line and not line.strip().startswith("#"):
                key, _, value = line.partition("=")
                key_lower = key.strip().lower()

                # NOTE(review): substring match is broad — "key" also matches
                # e.g. "monkey_count". Confirm whether that over-masking is
                # acceptable for sensitive files.
                sensitive_keys = ["password", "secret", "token", "key", "api", "credential"]
                if any(sk in key_lower for sk in sensitive_keys):
                    masked_lines.append(f"{key}=[REDACTED]")
                else:
                    masked_lines.append(line)
            else:
                masked_lines.append(line)

        return "\n".join(masked_lines)

    def _log_redaction(self, count: int, preview: str):
        """Log redaction event to audit trail (best-effort; never raises)."""
        try:
            import hashlib  # local import: only needed on the audit path

            AUDIT_DIR.mkdir(parents=True, exist_ok=True)
            audit_file = AUDIT_DIR / "redactions.jsonl"

            entry = {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "redactions": count,
                # sha256 instead of built-in hash(): hash() is salted per
                # process (PYTHONHASHSEED), so its values could not be
                # correlated across runs in an audit trail.
                "preview_hash": hashlib.sha256(preview.encode("utf-8")).hexdigest(),
            }

            with open(audit_file, "a") as f:
                f.write(json.dumps(entry) + "\n")

        except Exception as e:
            logger.debug("Audit log failed: %s", e)
|
||||
|
||||
|
||||
# Module-level redactor
# Shared singleton: the convenience wrappers below all funnel through one
# CredentialRedactor, so the lifetime redaction count and the audit log
# are global to the process.
_redactor = CredentialRedactor()


def redact_credentials(text: str) -> Tuple[str, int]:
    """Redact credentials from text. Returns (redacted_text, redaction_count)."""
    return _redactor.redact(text)


def redact_tool_output(tool_name: str, output: str) -> Tuple[str, str]:
    """Redact tool output and return notice. Returns (redacted_output, notice_or_empty)."""
    return _redactor.redact_tool_output(tool_name, output)


def should_mask_file(file_path: str) -> bool:
    """Check if file should be masked (its name matches SENSITIVE_FILE_PATTERNS)."""
    return _redactor.should_mask_file(file_path)


def mask_sensitive_file(content: str, file_path: str) -> str:
    """Mask credentials in sensitive file; returns content unchanged for other files."""
    return _redactor.mask_file_content(content, file_path)
|
||||
@@ -327,6 +327,33 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
# ── Path existence guard (poka-yoke #887) ─────────────────────
|
||||
# Check if file exists before attempting read. 83.7% of read_file
|
||||
# errors are file-not-found — the agent hallucinates paths.
|
||||
# This guard catches them early with a clear, actionable error.
|
||||
if not _resolved.exists():
|
||||
# Try to suggest similar files in the same directory
|
||||
parent = _resolved.parent
|
||||
suggestion = ""
|
||||
if parent.exists() and parent.is_dir():
|
||||
similar = [
|
||||
f.name for f in parent.iterdir()
|
||||
if f.is_file() and _resolved.stem[:3].lower() in f.stem.lower()
|
||||
][:5]
|
||||
if similar:
|
||||
suggestion = f" Similar files in {parent}: {', '.join(similar)}"
|
||||
return json.dumps({
|
||||
"error": (
|
||||
f"File not found: '{path}'. The file does not exist at the resolved path "
|
||||
f"({_resolved}).{suggestion} "
|
||||
"Use search_files to find the correct path first."
|
||||
),
|
||||
"path": path,
|
||||
"resolved": str(_resolved),
|
||||
"suggestion": "Use search_files(pattern='...', target='files') to find files.",
|
||||
})
|
||||
|
||||
# ── Dedup check ───────────────────────────────────────────────
|
||||
# If we already read this exact (path, offset, limit) and the
|
||||
# file hasn't been modified since, return a lightweight stub
|
||||
|
||||
113
tools/hardcoded_path_guard.py
Normal file
113
tools/hardcoded_path_guard.py
Normal file
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Hardcoded Path Guard — Poka-Yoke #921
|
||||
|
||||
Detects and blocks hardcoded home-directory paths in tool arguments.
|
||||
These paths work on one machine but break on others, VPS deployments,
|
||||
or when HOME changes.
|
||||
|
||||
Usage:
|
||||
from tools.hardcoded_path_guard import check_path, validate_tool_args
|
||||
|
||||
# Check a single path
|
||||
err = check_path("/Users/apayne/.hermes/config.yaml")
|
||||
|
||||
# Validate all path-like args in a tool call
|
||||
clean_args, warnings = validate_tool_args("read_file", {"path": "/home/user/file.txt"})
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json as _json
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
|
||||
# Patterns that indicate hardcoded home directories.
# Each entry is (regex, human-readable description used in error messages).
HARDCODED_PATTERNS = [
    (r"/Users/[\w.\-]+/", "macOS home directory (/Users/...)"),
    (r"/home/[\w.\-]+/", "Linux home directory (/home/...)"),
    (r"(?<![\w/])~/", "unexpanded tilde (~/)"),
    (r"/root/", "root home directory (/root/)"),
]

_COMPILED_PATTERNS = [(re.compile(p), desc) for p, desc in HARDCODED_PATTERNS]
# Opt-out marker: any text carrying this comment is never flagged.
_NOQA_PATTERN = re.compile(r"#\s*noqa:?\s*hardcoded-path-ok")

# Argument names that conventionally carry filesystem paths.
_PATH_ARG_NAMES = frozenset({
    "path", "file_path", "filepath", "dir", "directory", "dest", "source",
    "input", "output", "src", "dst", "target", "location", "file",
    "image_path", "script", "config", "log_file",
})


def has_hardcoded_path(text: str) -> Optional[str]:
    """Return the description of the first hardcoded-home pattern in *text*.

    Returns None when nothing matches or the text carries the noqa opt-out.
    """
    if _NOQA_PATTERN.search(text):
        return None
    for pattern, desc in _COMPILED_PATTERNS:
        if pattern.search(text):
            return desc
    return None


def check_path(path_value: str) -> Optional[str]:
    """Return an actionable error message if *path_value* is hardcoded, else None.

    Non-string values are silently accepted (returns None).
    """
    if not isinstance(path_value, str):
        return None
    match_desc = has_hardcoded_path(path_value)
    if match_desc:
        return (
            f"Path contains hardcoded home directory ({match_desc}): '{path_value}'. "
            f"Use $HOME, relative paths, or get_hermes_home(). "
            f"Add '# noqa: hardcoded-path-ok' if intentional."
        )
    return None


def validate_tool_args(tool_name: str, args: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
    """Check path-like arguments of a tool call for hardcoded home paths.

    Only arguments whose lowercased name is in _PATH_ARG_NAMES are inspected;
    string values and lists of strings are supported. The args dict is
    returned unmodified — this guard warns, it does not rewrite.

    Returns:
        (args, warnings) where warnings is a list of human-readable errors.
    """
    warnings: List[str] = []
    for key, value in args.items():
        if key.lower() not in _PATH_ARG_NAMES:
            continue
        candidates = value if isinstance(value, list) else [value]
        for item in candidates:
            if isinstance(item, str):
                err = check_path(item)
                if err:
                    warnings.append(err)
    return args, warnings


def scan_source_for_violations(source_code: str, filename: str = "") -> List[Tuple[int, str, str]]:
    """Scan source text for hardcoded home-directory paths.

    Comment-only lines, import lines, and lines carrying the noqa opt-out
    are ignored. (The original implementation had a dead inner noqa check
    inside the comment branch — the branch continued regardless.)

    Returns:
        List of (1-based line number, stripped line, pattern description).
    """
    violations: List[Tuple[int, str, str]] = []
    for lineno, line in enumerate(source_code.split("\n"), 1):
        stripped = line.strip()
        # Comment-only and import lines never count as violations.
        if stripped.startswith(("#", "import ", "from ")):
            continue
        if _NOQA_PATTERN.search(line):
            continue
        for pattern, desc in _COMPILED_PATTERNS:
            if pattern.search(line):
                violations.append((lineno, stripped, desc))
                break  # one report per line is enough
    return violations


def guard_tool_dispatch(tool_name: str, args: Dict[str, Any]) -> Optional[str]:
    """Return a JSON error payload if *args* contain hardcoded paths, else None.

    Intended to be called before dispatching a tool; a non-None result is a
    ready-to-return tool error string.
    """
    _, warnings = validate_tool_args(tool_name, args)
    if not warnings:
        return None
    return _json.dumps({
        "error": "Hardcoded home directory path detected",
        "details": warnings,
        "suggestion": "Use $HOME, relative paths, or get_hermes_home() instead of hardcoded paths.",
        "pokayoke": True,
        "rule": "hardcoded-path-guard",
    })
|
||||
298
tools/poka_yoke.py
Normal file
298
tools/poka_yoke.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""
|
||||
poka_yoke.py — Validation firewall for tool calls.
|
||||
|
||||
Poka-yoke (mistake-proofing): validates tool calls against the registry
|
||||
before execution. Catches hallucinated tool names, malformed parameters,
|
||||
missing required arguments, and type mismatches.
|
||||
|
||||
Usage:
|
||||
from tools.poka_yoke import validate_tool_call
|
||||
|
||||
is_valid, corrected_name, corrected_params, messages = validate_tool_call(
|
||||
"read_file", {"path": "test.txt"}
|
||||
)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_tool_call(
    function_name: str,
    function_args: Dict[str, Any],
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]], List[str]]:
    """Validate a tool call against the registry before execution.

    Args:
        function_name: The tool name from the LLM's function_call.
        function_args: The arguments dict from the LLM's function_call.

    Returns:
        Tuple of (is_valid, corrected_name, corrected_params, messages):
        - is_valid: False if the call should be blocked entirely.
        - corrected_name: Suggested name if a close match was found (None if OK).
        - corrected_params: Corrected params if type coercion fixed issues (None if OK).
        - messages: List of error/warning/info messages.
    """
    # Function-scope import — NOTE(review): presumably to avoid a circular
    # import with tools.registry at module load time; confirm.
    from tools.registry import registry

    messages: List[str] = []
    corrected_name: Optional[str] = None
    corrected_params: Optional[Dict[str, Any]] = None

    # ── 1. Check if tool exists ───────────────────────────────────────────

    entry = registry.get_entry(function_name)

    if entry is None:
        # Tool not found — suggest closest match
        all_names = registry.get_all_tool_names()
        suggestion = _find_closest_name(function_name, all_names)

        if suggestion:
            messages.append(
                f"Unknown tool '{function_name}'. Did you mean '{suggestion}'?"
            )
            corrected_name = suggestion
            # Re-validate with corrected name
            entry = registry.get_entry(suggestion)
            if entry is None:
                return False, corrected_name, None, messages
        else:
            # No plausible match: list (up to 20) available tools to help
            # the agent self-correct, then block the call.
            available = ", ".join(sorted(all_names)[:20])
            messages.append(
                f"Unknown tool '{function_name}'. "
                f"Available tools: {available}{'...' if len(all_names) > 20 else ''}"
            )
            return False, None, None, messages

    # ── 2. Validate parameters against schema ─────────────────────────────

    schema = entry.schema
    params_schema = schema.get("parameters", {})
    properties = params_schema.get("properties", {})
    required = set(params_schema.get("required", []))

    # Check for missing required parameters (sorted for stable messages)
    for param_name in sorted(required):
        if param_name not in function_args:
            param_info = properties.get(param_name, {})
            param_type = param_info.get("type", "any")
            messages.append(
                f"Missing required parameter '{param_name}' "
                f"(expected type: {param_type}). "
                f"Tool: {function_name}"
            )

    # If required params are missing, we still return the error
    # (the agent might be able to self-correct)
    if any("Missing required" in m for m in messages):
        # Don't block — return the error as a tool result so the agent can retry
        # But mark as invalid so caller knows
        return False, corrected_name, corrected_params, messages

    # ── 3. Check for unknown parameters ───────────────────────────────────

    if properties:
        known_params = set(properties.keys())
        # Allow extra params that start with _ (internal convention)
        unknown = [
            p for p in function_args
            if p not in known_params and not p.startswith("_")
        ]
        if unknown:
            known_str = ", ".join(sorted(known_params))
            unknown_str = ", ".join(sorted(unknown))
            messages.append(
                f"Unknown parameter(s) for '{function_name}': {unknown_str}. "
                f"Known parameters: {known_str}"
            )
            # Remove unknown params (don't block, just clean)
            corrected_params = {
                k: v for k, v in function_args.items()
                if k in known_params or k.startswith("_")
            }

    # ── 4. Type validation ────────────────────────────────────────────────

    type_errors = []
    # Work on a copy so coercions don't mutate the caller's dict.
    coerced = dict(corrected_params or function_args)

    for param_name, param_value in coerced.items():
        if param_name.startswith("_"):
            continue
        param_schema = properties.get(param_name)
        if not param_schema:
            continue

        expected_type = param_schema.get("type")
        if not expected_type:
            continue

        is_valid_type, coerced_value = _validate_type(
            param_name, param_value, expected_type
        )
        if not is_valid_type:
            type_errors.append(
                f"Parameter '{param_name}': expected {expected_type}, "
                f"got {type(param_value).__name__} ({_truncate(str(param_value), 50)})"
            )
        elif coerced_value is not param_value:
            # Identity (not equality) check: _validate_type returns the same
            # object when no coercion happened and a new object when it did.
            coerced[param_name] = coerced_value
            messages.append(
                f"Parameter '{param_name}': coerced from "
                f"{type(param_value).__name__} to {expected_type}"
            )

    if type_errors:
        messages.extend(type_errors)
        return False, corrected_name, corrected_params, messages

    if coerced != (corrected_params or function_args):
        corrected_params = coerced

    # ── 5. Enum validation ────────────────────────────────────────────────

    # An out-of-enum value blocks the call immediately.
    for param_name, param_value in (corrected_params or function_args).items():
        param_schema = properties.get(param_name, {})
        enum_values = param_schema.get("enum")
        if enum_values and param_value not in enum_values:
            messages.append(
                f"Parameter '{param_name}': value '{param_value}' not in "
                f"allowed values: {enum_values}"
            )
            return False, corrected_name, corrected_params, messages

    # ── 6. Pattern validation ─────────────────────────────────────────────

    # Pattern mismatches only add a warning message — they do not flip
    # is_valid below (unlike enum/type failures, which returned above).
    for param_name, param_value in (corrected_params or function_args).items():
        if not isinstance(param_value, str):
            continue
        param_schema = properties.get(param_name, {})
        pattern = param_schema.get("pattern")
        if pattern and not re.match(pattern, param_value):
            messages.append(
                f"Parameter '{param_name}': value '{_truncate(param_value, 50)}' "
                f"does not match pattern '{pattern}'"
            )

    # ── Done ──────────────────────────────────────────────────────────────

    # Only missing-required errors can still mark the call invalid here;
    # hard failures (unknown tool, type, enum) already returned early.
    is_valid = not any("Missing required" in m for m in messages)

    if is_valid and not messages:
        return True, None, None, []

    return is_valid, corrected_name, corrected_params, messages
|
||||
|
||||
|
||||
def _find_closest_name(target: str, candidates: List[str]) -> Optional[str]:
|
||||
"""Find the closest tool name using simple edit distance heuristics."""
|
||||
if not candidates:
|
||||
return None
|
||||
|
||||
target_lower = target.lower()
|
||||
|
||||
# Exact prefix match
|
||||
for name in candidates:
|
||||
if name.lower().startswith(target_lower[:4]) and len(target_lower) > 3:
|
||||
return name
|
||||
|
||||
# Substring match
|
||||
for name in candidates:
|
||||
if target_lower in name.lower() or name.lower() in target_lower:
|
||||
return name
|
||||
|
||||
# Levenshtein distance (simple, for short strings)
|
||||
def _levenshtein(a: str, b: str) -> int:
|
||||
if len(a) < len(b):
|
||||
return _levenshtein(b, a)
|
||||
if len(b) == 0:
|
||||
return len(a)
|
||||
prev = range(len(b) + 1)
|
||||
for i, ca in enumerate(a):
|
||||
curr = [i + 1]
|
||||
for j, cb in enumerate(b):
|
||||
curr.append(min(
|
||||
prev[j + 1] + 1,
|
||||
curr[j] + 1,
|
||||
prev[j] + (0 if ca == cb else 1),
|
||||
))
|
||||
prev = curr
|
||||
return prev[-1]
|
||||
|
||||
distances = [(name, _levenshtein(target_lower, name.lower())) for name in candidates]
|
||||
distances.sort(key=lambda x: x[1])
|
||||
|
||||
# Return if edit distance is small enough
|
||||
if distances and distances[0][1] <= max(3, len(target) // 3):
|
||||
return distances[0][0]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _validate_type(
|
||||
param_name: str, value: Any, expected_type: str
|
||||
) -> Tuple[bool, Any]:
|
||||
"""Validate and optionally coerce a parameter value to the expected type.
|
||||
|
||||
Returns (is_valid, coerced_value). coerced_value is value itself if no
|
||||
coercion was needed.
|
||||
"""
|
||||
type_map = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": (int, float),
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict,
|
||||
}
|
||||
|
||||
expected = type_map.get(expected_type)
|
||||
if expected is None:
|
||||
return True, value # Unknown type, skip validation
|
||||
|
||||
# Direct type check
|
||||
if isinstance(value, expected):
|
||||
return True, value
|
||||
|
||||
# Coercion attempts
|
||||
if expected_type == "string":
|
||||
return True, str(value)
|
||||
|
||||
if expected_type == "integer":
|
||||
if isinstance(value, str) and value.isdigit():
|
||||
return True, int(value)
|
||||
if isinstance(value, float) and value == int(value):
|
||||
return True, int(value)
|
||||
return False, value
|
||||
|
||||
if expected_type == "number":
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return True, float(value)
|
||||
except ValueError:
|
||||
return False, value
|
||||
return False, value
|
||||
|
||||
if expected_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
lower = value.lower()
|
||||
if lower in ("true", "1", "yes"):
|
||||
return True, True
|
||||
if lower in ("false", "0", "no"):
|
||||
return True, False
|
||||
return False, value
|
||||
|
||||
return False, value
|
||||
|
||||
|
||||
def _truncate(s: str, max_len: int) -> str:
|
||||
"""Truncate a string for display."""
|
||||
if len(s) <= max_len:
|
||||
return s
|
||||
return s[:max_len - 3] + "..."
|
||||
275
tools/session_templates.py
Normal file
275
tools/session_templates.py
Normal file
@@ -0,0 +1,275 @@
|
||||
"""
|
||||
Session templates for code-first seeding.
|
||||
|
||||
Research: Code-heavy sessions (execute_code dominant in first 30 turns) improve over time.
|
||||
File-heavy sessions degrade. Key is deterministic feedback loops.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sqlite3
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass, asdict, field
|
||||
from enum import Enum
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TEMPLATE_DIR = Path.home() / ".hermes" / "session-templates"
|
||||
|
||||
|
||||
class TaskType(Enum):
    """Dominant tool profile of a session, used to bucket templates.

    A session is assigned the category whose tools account for more than
    60% of its calls; otherwise it is MIXED.
    """

    CODE = "code"          # execute_code / code_execution dominant
    FILE = "file"          # read/write/patch/search dominant
    RESEARCH = "research"  # web search / fetch / browser dominant
    MIXED = "mixed"        # no single category above the threshold
|
||||
|
||||
|
||||
@dataclass
class ToolExample:
    """One recorded tool invocation (call + result) used to seed a session."""

    tool_name: str             # Name of the tool that was invoked.
    arguments: Dict[str, Any]  # Arguments the tool was called with.
    result: str                # Tool output; may be "" until backfilled.
    success: bool              # Whether the invocation was treated as successful.
    turn: int = 0              # 0-based turn index within the source session.

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolExample":
        """Rebuild an instance from a to_dict() payload."""
        return cls(**data)
|
||||
|
||||
|
||||
@dataclass
class Template:
    """A named bundle of ToolExamples extracted from a past session.

    Attributes:
        name: Unique template name (also used as the JSON filename stem).
        task_type: Dominant tool profile of the source session.
        examples: Ordered tool invocations used to seed a new session.
        desc: Free-form description.
        created: Unix timestamp; filled in by __post_init__ when left at 0.
        used: How many times this template has been applied.
        session_id: Source session identifier, if known.
        tags: Free-form labels for lookup.
    """

    name: str
    task_type: TaskType
    examples: List[ToolExample]
    desc: str = ""
    created: float = 0.0
    used: int = 0
    session_id: Optional[str] = None
    tags: List[str] = field(default_factory=list)

    def __post_init__(self):
        # Stamp creation time lazily so deserialized templates keep theirs.
        if self.created == 0.0:
            self.created = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enum stored by its value)."""
        d = asdict(self)
        d['task_type'] = self.task_type.value
        return d

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Template":
        """Rebuild a Template from a to_dict() payload.

        Works on a shallow copy so the caller's dict is not mutated —
        the previous implementation rewrote 'task_type' and 'examples'
        in place in the argument.
        """
        data = dict(data)
        data['task_type'] = TaskType(data['task_type'])
        data['examples'] = [ToolExample.from_dict(e) for e in data.get('examples', [])]
        return cls(**data)
|
||||
|
||||
|
||||
class Templates:
    """Disk-backed store of Template objects, one JSON file per template.

    Templates are loaded eagerly from `dir` on construction and written back
    immediately on create/inject.
    """

    def __init__(self, dir=None):
        # `dir` keeps its name for backward compatibility even though it
        # shadows the builtin.
        self.dir = dir or TEMPLATE_DIR
        self.dir.mkdir(parents=True, exist_ok=True)
        self.templates = {}
        self._load()

    def _load(self):
        """Load every *.json template from disk, skipping unreadable files."""
        for f in self.dir.glob("*.json"):
            try:
                with open(f) as fh:
                    t = Template.from_dict(json.load(fh))
                self.templates[t.name] = t
            except Exception as e:
                logger.warning(f"Load failed {f}: {e}")

    def _save(self, t):
        """Persist a single template to `<dir>/<name>.json`."""
        with open(self.dir / f"{t.name}.json", 'w') as f:
            json.dump(t.to_dict(), f, indent=2)

    def classify(self, calls):
        """Classify a list of tool-call dicts into a TaskType.

        A category wins when strictly more than 60% of the calls use tools
        from that category; otherwise the result is MIXED.
        """
        if not calls:
            return TaskType.MIXED
        code = {'execute_code', 'code_execution'}
        file_ops = {'read_file', 'write_file', 'patch', 'search_files'}
        research = {'web_search', 'web_fetch', 'browser_navigate'}
        names = [c.get('tool_name', '') for c in calls]
        total = len(names)
        if sum(1 for n in names if n in code) / total > 0.6:
            return TaskType.CODE
        if sum(1 for n in names if n in file_ops) / total > 0.6:
            return TaskType.FILE
        if sum(1 for n in names if n in research) / total > 0.6:
            return TaskType.RESEARCH
        return TaskType.MIXED

    def extract(self, session_id, max_n=10):
        """Extract up to `max_n` tool-call examples from a stored session.

        Reads the messages table in ~/.hermes/state.db. Each assistant tool
        call becomes a ToolExample; the following 'tool' message (if any)
        fills in its result. Returns [] on any failure.
        """
        db = Path.home() / ".hermes" / "state.db"
        if not db.exists():
            return []
        try:
            conn = sqlite3.connect(str(db))
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT role, content, tool_calls FROM messages WHERE session_id=? ORDER BY timestamp LIMIT 100",
                (session_id,)
            ).fetchall()
            conn.close()
            examples = []
            turn = 0
            for r in rows:
                if len(examples) >= max_n:
                    break
                if r['role'] == 'assistant' and r['tool_calls']:
                    try:
                        for tc in json.loads(r['tool_calls']):
                            if len(examples) >= max_n:
                                break
                            name = tc.get('function', {}).get('name')
                            if not name:
                                continue
                            try:
                                args = json.loads(tc.get('function', {}).get('arguments', '{}'))
                            except (json.JSONDecodeError, TypeError):
                                # Malformed arguments — keep the example, drop the args.
                                # (Was a bare `except:`, which also swallowed
                                # KeyboardInterrupt/SystemExit.)
                                args = {}
                            examples.append(ToolExample(name, args, "", True, turn))
                        # NOTE(review): one turn per assistant message; the
                        # original's indentation was ambiguous here — confirm.
                        turn += 1
                    except (json.JSONDecodeError, TypeError, AttributeError):
                        # Malformed tool_calls payload — skip this message.
                        continue
                elif r['role'] == 'tool' and examples and examples[-1].result == "":
                    # Attach the tool output to the most recent example.
                    examples[-1].result = r['content'] or ""
            return examples
        except Exception as e:
            logger.error(f"Extract failed: {e}")
            return []

    def create(self, session_id, name=None, task_type=None, max_n=10, desc="", tags=None):
        """Build, register and persist a template from a past session.

        Returns the new Template, or None when the session yielded no
        examples. Name and type default to auto-derived values.
        """
        examples = self.extract(session_id, max_n)
        if not examples:
            return None
        if task_type is None:
            task_type = self.classify([{'tool_name': e.tool_name} for e in examples])
        if name is None:
            name = f"{task_type.value}_{session_id[:8]}_{int(time.time())}"
        t = Template(name, task_type, examples, desc or f"{len(examples)} examples", time.time(), 0, session_id, tags or [])
        self.templates[name] = t
        self._save(t)
        logger.info(f"Created {name} with {len(examples)} examples")
        return t

    def get(self, task_type, tags=None):
        """Return a template matching `task_type` (and any of `tags`), or None.

        Picks the least-used match so usage spreads evenly across templates.
        """
        matching = [t for t in self.templates.values() if t.task_type == task_type]
        if tags:
            matching = [t for t in matching if any(tag in t.tags for tag in tags)]
        if not matching:
            return None
        matching.sort(key=lambda t: t.used)
        return matching[0]

    def inject(self, template, messages):
        """Splice the template's example exchange into `messages` in place.

        The injected block (one system note plus assistant/tool pairs) goes
        immediately after any leading system messages. Increments and
        persists the template's usage counter. Returns `messages`.
        """
        if not template.examples:
            return messages
        injection = [{
            "role": "system",
            "content": f"Template: {template.name} ({template.task_type.value})\n{template.desc}"
        }]
        for i, ex in enumerate(template.examples):
            injection.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": f"tpl_{i}",
                    "type": "function",
                    "function": {"name": ex.tool_name, "arguments": json.dumps(ex.arguments)}
                }]
            })
            injection.append({
                "role": "tool",
                "tool_call_id": f"tpl_{i}",
                "content": ex.result
            })
        # Find the first non-system message; inject just before it.
        idx = 0
        for i, m in enumerate(messages):
            if m.get("role") != "system":
                break
            idx = i + 1
        messages[idx:idx] = injection  # single splice instead of repeated insert()
        template.used += 1
        self._save(template)
        return messages

    def list(self, task_type=None, tags=None):
        """Return templates (optionally filtered by type/tags), newest first."""
        ts = list(self.templates.values())
        if task_type:
            ts = [t for t in ts if t.task_type == task_type]
        if tags:
            ts = [t for t in ts if any(tag in t.tags for tag in tags)]
        ts.sort(key=lambda t: t.created, reverse=True)
        return ts

    def delete(self, name):
        """Remove a template from memory and disk. Returns True if it existed."""
        if name not in self.templates:
            return False
        del self.templates[name]
        p = self.dir / f"{name}.json"
        if p.exists():
            p.unlink()
        return True

    def stats(self):
        """Aggregate counts: total templates, per-type counts, examples, usage."""
        if not self.templates:
            return {"total": 0, "by_type": {}, "examples": 0, "usage": 0}
        by_type = {}
        total_ex = 0
        total_use = 0
        for t in self.templates.values():
            by_type[t.task_type.value] = by_type.get(t.task_type.value, 0) + 1
            total_ex += len(t.examples)
            total_use += t.used
        return {"total": len(self.templates), "by_type": by_type, "examples": total_ex, "usage": total_use}
|
||||
|
||||
|
||||
if __name__ == "__main__":
    import argparse

    # Tiny CLI for inspecting and managing templates.
    parser = argparse.ArgumentParser()
    sub = parser.add_subparsers(dest="cmd")

    list_cmd = sub.add_parser("list")
    list_cmd.add_argument("--type", choices=["code", "file", "research", "mixed"])
    list_cmd.add_argument("--tags")

    create_cmd = sub.add_parser("create")
    create_cmd.add_argument("session_id")
    create_cmd.add_argument("--name")
    create_cmd.add_argument("--type", choices=["code", "file", "research", "mixed"])
    create_cmd.add_argument("--max", type=int, default=10)
    create_cmd.add_argument("--desc")
    create_cmd.add_argument("--tags")

    delete_cmd = sub.add_parser("delete")
    delete_cmd.add_argument("name")

    sub.add_parser("stats")

    args = parser.parse_args()
    store = Templates()

    if args.cmd == "list":
        wanted_type = TaskType(args.type) if args.type else None
        wanted_tags = args.tags.split(",") if args.tags else None
        for tpl in store.list(wanted_type, wanted_tags):
            print(f"{tpl.name}: {tpl.task_type.value} ({len(tpl.examples)} ex, used {tpl.used}x)")
    elif args.cmd == "create":
        wanted_type = TaskType(args.type) if args.type else None
        wanted_tags = args.tags.split(",") if args.tags else None
        tpl = store.create(args.session_id, args.name, wanted_type, args.max, args.desc or "", wanted_tags)
        if tpl is not None:
            print(f"Created: {tpl.name} ({len(tpl.examples)} examples)")
        else:
            print("Failed")
    elif args.cmd == "delete":
        print("Deleted" if store.delete(args.name) else "Not found")
    elif args.cmd == "stats":
        info = store.stats()
        print(f"Total: {info['total']}, Examples: {info['examples']}, Usage: {info['usage']}")
        for task_name, count in info['by_type'].items():
            print(f" {task_name}: {count}")
    else:
        parser.print_help()
|
||||
@@ -38,12 +38,41 @@ import os
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from hermes_constants import get_hermes_home, display_hermes_home
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_error(
|
||||
message: str,
|
||||
skill_name: str = None,
|
||||
file_path: str = None,
|
||||
suggestion: str = None,
|
||||
context: dict = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Format an error with rich context for better debugging."""
|
||||
parts = [message]
|
||||
if skill_name:
|
||||
parts.append(f"Skill: {skill_name}")
|
||||
if file_path:
|
||||
parts.append(f"File: {file_path}")
|
||||
if suggestion:
|
||||
parts.append(f"Suggestion: {suggestion}")
|
||||
if context:
|
||||
for key, value in context.items():
|
||||
parts.append(f"{key}: {value}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": " | ".join(parts),
|
||||
"skill_name": skill_name,
|
||||
"file_path": file_path,
|
||||
"suggestion": suggestion,
|
||||
}
|
||||
|
||||
|
||||
# Import security scanner — agent-created skills get the same scrutiny as
|
||||
# community hub installs.
|
||||
try:
|
||||
@@ -253,6 +282,94 @@ def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Pat
|
||||
return target, None
|
||||
|
||||
|
||||
MAX_BACKUPS_PER_FILE = 3
|
||||
|
||||
|
||||
def _backup_skill_file(file_path: Path) -> Optional[Path]:
|
||||
"""Create a timestamped backup of a skill file before modification.
|
||||
|
||||
The backup is named ``{original_name}.bak.{unix_timestamp}`` and placed
|
||||
in the same directory. Returns the backup path, or *None* if the file
|
||||
does not exist yet (nothing to back up).
|
||||
"""
|
||||
if not file_path.exists():
|
||||
return None
|
||||
timestamp = int(time.time())
|
||||
backup_path = file_path.parent / f"{file_path.name}.bak.{timestamp}"
|
||||
shutil.copy2(str(file_path), str(backup_path))
|
||||
return backup_path
|
||||
|
||||
|
||||
def _cleanup_old_backups(file_path: Path, max_backups: int = MAX_BACKUPS_PER_FILE) -> None:
    """Prune backups of *file_path* so at most *max_backups* remain.

    A backup is any regular file named ``{file_path.name}.bak.*`` sitting
    next to the original; the oldest (by mtime) are deleted first.
    Filesystem errors abort the pruning quietly — this is best-effort.
    """
    directory = file_path.parent
    marker = file_path.name + ".bak."
    try:
        candidates: List[Path] = [
            entry for entry in directory.iterdir()
            if entry.name.startswith(marker) and entry.is_file()
        ]
        candidates.sort(key=lambda entry: entry.stat().st_mtime)
    except OSError:
        return
    excess = len(candidates) - max_backups
    if excess <= 0:
        return
    for stale in candidates[:excess]:
        try:
            stale.unlink()
        except OSError:
            break
|
||||
|
||||
|
||||
def _validate_written_file(file_path: Path, is_skill_md: bool = False) -> Optional[str]:
|
||||
"""Re-read a file from disk and validate it after writing.
|
||||
|
||||
Catches filesystem-level issues (truncation, encoding errors, empty
|
||||
writes) that pre-write validation cannot detect. For SKILL.md files
|
||||
the frontmatter is also re-validated.
|
||||
|
||||
Returns an error message, or *None* if the file looks healthy.
|
||||
"""
|
||||
try:
|
||||
content = file_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
return f"Failed to read file after write: {exc}"
|
||||
except UnicodeDecodeError as exc:
|
||||
return f"File encoding error after write: {exc}"
|
||||
|
||||
if len(content) == 0:
|
||||
return "File is empty after write (possible truncation)."
|
||||
|
||||
if is_skill_md:
|
||||
err = _validate_frontmatter(content)
|
||||
if err:
|
||||
return f"Post-write validation failed: {err}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _revert_from_backup(file_path: Path, backup_path: Optional[Path]) -> None:
|
||||
"""Restore *file_path* from *backup_path*.
|
||||
|
||||
If *backup_path* is None or missing the target file is removed so the
|
||||
skill directory is at least not left with corrupted content.
|
||||
"""
|
||||
if backup_path and backup_path.exists():
|
||||
try:
|
||||
shutil.copy2(str(backup_path), str(file_path))
|
||||
except OSError:
|
||||
logger.error(
|
||||
"Failed to restore %s from backup %s", file_path, backup_path, exc_info=True
|
||||
)
|
||||
else:
|
||||
# No backup — remove the partially-written file
|
||||
try:
|
||||
file_path.unlink(missing_ok=True)
|
||||
except OSError:
|
||||
logger.error("Failed to remove corrupted file %s after failed write", file_path, exc_info=True)
|
||||
|
||||
|
||||
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
|
||||
"""
|
||||
Atomically write text content to a file.
|
||||
@@ -358,20 +475,35 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."}
|
||||
return _format_error(
|
||||
f"Skill '{name}' not found.",
|
||||
skill_name=name,
|
||||
suggestion="Use skills_list() to see available skills.",
|
||||
)
|
||||
|
||||
skill_md = existing["path"] / "SKILL.md"
|
||||
# Back up original content for rollback
|
||||
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
||||
|
||||
# --- Transactional write-validate-commit-or-rollback ---
|
||||
backup_path = _backup_skill_file(skill_md)
|
||||
_atomic_write_text(skill_md, content)
|
||||
|
||||
# Post-write validation: catch filesystem-level failures
|
||||
validate_err = _validate_written_file(skill_md, is_skill_md=True)
|
||||
if validate_err:
|
||||
_revert_from_backup(skill_md, backup_path)
|
||||
return {"success": False, "error": f"Edit reverted: {validate_err}"}
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(skill_md, original_content)
|
||||
_revert_from_backup(skill_md, backup_path)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
# Success — remove the backup we just created, prune any older ones
|
||||
if backup_path:
|
||||
backup_path.unlink(missing_ok=True)
|
||||
_cleanup_old_backups(skill_md)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Skill '{name}' updated.",
|
||||
@@ -392,13 +524,25 @@ def _patch_skill(
|
||||
Requires a unique match unless replace_all is True.
|
||||
"""
|
||||
if not old_string:
|
||||
return {"success": False, "error": "old_string is required for 'patch'."}
|
||||
return _format_error(
|
||||
"old_string is required for 'patch'.",
|
||||
skill_name=name,
|
||||
suggestion="Provide the exact text to find in the skill file.",
|
||||
)
|
||||
if new_string is None:
|
||||
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
|
||||
return _format_error(
|
||||
"new_string is required for 'patch'. Use an empty string to delete matched text.",
|
||||
skill_name=name,
|
||||
suggestion="Pass new_string='' to delete the matched text.",
|
||||
)
|
||||
|
||||
existing = _find_skill(name)
|
||||
if not existing:
|
||||
return {"success": False, "error": f"Skill '{name}' not found."}
|
||||
return _format_error(
|
||||
f"Skill '{name}' not found.",
|
||||
skill_name=name,
|
||||
suggestion="Use skills_list() to see available skills.",
|
||||
)
|
||||
|
||||
skill_dir = existing["path"]
|
||||
|
||||
@@ -452,15 +596,29 @@ def _patch_skill(
|
||||
"error": f"Patch would break SKILL.md structure: {err}",
|
||||
}
|
||||
|
||||
original_content = content # for rollback
|
||||
is_skill_md = not file_path
|
||||
|
||||
# --- Transactional write-validate-commit-or-rollback ---
|
||||
backup_path = _backup_skill_file(target)
|
||||
_atomic_write_text(target, new_content)
|
||||
|
||||
# Post-write validation
|
||||
validate_err = _validate_written_file(target, is_skill_md=is_skill_md)
|
||||
if validate_err:
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": f"Patch reverted: {validate_err}"}
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(skill_dir)
|
||||
if scan_error:
|
||||
_atomic_write_text(target, original_content)
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
# Success — remove the backup we just created, prune any older ones
|
||||
if backup_path:
|
||||
backup_path.unlink(missing_ok=True)
|
||||
_cleanup_old_backups(target)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Patched {'SKILL.md' if not file_path else file_path} in skill '{name}' ({match_count} replacement{'s' if match_count > 1 else ''}).",
|
||||
@@ -519,19 +677,28 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
||||
if err:
|
||||
return {"success": False, "error": err}
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Back up for rollback
|
||||
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
||||
|
||||
# --- Transactional write-validate-commit-or-rollback ---
|
||||
backup_path = _backup_skill_file(target)
|
||||
_atomic_write_text(target, file_content)
|
||||
|
||||
# Post-write validation: ensure the file is readable and non-empty
|
||||
validate_err = _validate_written_file(target, is_skill_md=False)
|
||||
if validate_err:
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": f"Write reverted: {validate_err}"}
|
||||
|
||||
# Security scan — roll back on block
|
||||
scan_error = _security_scan_skill(existing["path"])
|
||||
if scan_error:
|
||||
if original_content is not None:
|
||||
_atomic_write_text(target, original_content)
|
||||
else:
|
||||
target.unlink(missing_ok=True)
|
||||
_revert_from_backup(target, backup_path)
|
||||
return {"success": False, "error": scan_error}
|
||||
|
||||
# Success — remove the backup we just created, prune any older ones
|
||||
if backup_path:
|
||||
backup_path.unlink(missing_ok=True)
|
||||
_cleanup_old_backups(target)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"File '{file_path}' written to skill '{name}'.",
|
||||
|
||||
109
tools/static_analyzer.py
Normal file
109
tools/static_analyzer.py
Normal file
@@ -0,0 +1,109 @@
|
||||
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
GOFAI Static Analyzer — Deterministic risk assessment for autonomous code.
|
||||
|
||||
Detects high-risk patterns like infinite loops, resource exhaustion,
|
||||
and circular dependencies using AST analysis.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any
|
||||
from tools.registry import registry, tool_error, tool_result
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# JSON Schema advertised to the model for the `static_analyze` tool;
# consumed by registry.register() at the bottom of this module.
STATIC_ANALYZE_SCHEMA = {
    "name": "static_analyze",
    "description": "Perform an advanced GOFAI static analysis of code. Detects infinite loops, potential memory leaks (unbounded collections), and circular dependency risks without using an LLM. Use this to ensure your code is 'Fleet-Safe'.",
    "parameters": {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "Path to the file to analyze."}
        },
        "required": ["path"]
    }
}
|
||||
|
||||
class RiskAnalyzer(ast.NodeVisitor):
    """AST visitor that records high-risk code patterns into ``self.risks``.

    Each risk is a dict with keys: type, location, severity, message.
    """

    def __init__(self):
        self.risks = []
        self.current_function = None  # enclosing function name, used for locations

    def visit_FunctionDef(self, node):
        # Track the enclosing function name while visiting its body.
        old_func = self.current_function
        self.current_function = node.name
        self.generic_visit(node)
        self.current_function = old_func

    # Async functions get the same name tracking as regular ones.
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_While(self, node):
        # Check for 'while True' with no break/return anywhere inside.
        # NOTE(review): ast.walk descends into nested loops and functions,
        # so an inner break/return counts as an exit — coarse heuristic.
        if isinstance(node.test, ast.Constant) and node.test.value is True:
            has_exit = any(isinstance(child, (ast.Break, ast.Return)) for child in ast.walk(node))
            if not has_exit:
                self.risks.append({
                    "type": "Infinite Loop Risk",
                    "location": f"{self.current_function or 'module'} (line {node.lineno})",
                    "severity": "HIGH",
                    "message": "Potential infinite loop: 'while True' found without clear break/return path."
                })
        self.generic_visit(node)

    def visit_For(self, node):
        # Flag mutation of the sequence being iterated (common error).
        # BUG FIX: the original called ast.walk(node.body) — node.body is a
        # *list*, which made ast.walk raise AttributeError and abort the
        # entire analysis for any for-loop. We walk each body statement
        # instead. It also compared against the loop *target* variable,
        # contradicting its own comment; we now check calls on the iterated
        # sequence itself.
        if isinstance(node.iter, ast.Name):
            seq_name = node.iter.id
            for stmt in node.body:
                for child in ast.walk(stmt):
                    if (
                        isinstance(child, ast.Call)
                        and isinstance(child.func, ast.Attribute)
                        and child.func.attr in ('append', 'extend', 'pop', 'remove')
                        and isinstance(child.func.value, ast.Name)
                        and child.func.value.id == seq_name
                    ):
                        self.risks.append({
                            "type": "Mutation Risk",
                            "location": f"{self.current_function or 'module'} (line {node.lineno})",
                            "severity": "MEDIUM",
                            "message": f"Loop mutates '{seq_name}' while iterating over it."
                        })
        self.generic_visit(node)
|
||||
|
||||
def run_analysis(path: str):
    """Run the static analysis pipeline over the file at *path*.

    Parses the file with `ast`, walks it with RiskAnalyzer, and returns a
    tool_result payload — either "Verified Safe" or a risk report. Any
    read/parse failure is returned as tool_error.
    """
    try:
        # BUG FIX: the original `open(path).read()` leaked the file handle
        # and relied on the platform default encoding.
        with open(path, "r", encoding="utf-8") as fh:
            source = fh.read()
        tree = ast.parse(source)

        analyzer = RiskAnalyzer()
        analyzer.visit(tree)

        if not analyzer.risks:
            return tool_result(
                status="Verified Safe",
                message="No high-risk GOFAI patterns detected. Code appears compliant with Fleet execution safety standards."
            )

        # One human-readable line per recorded risk.
        summary = "GOFAI RISK ASSESSMENT REPORT:\n"
        for risk in analyzer.risks:
            summary += f"- [{risk['severity']}] {risk['type']} in {risk['location']}: {risk['message']}\n"

        return tool_result(
            status="Risk Detected",
            summary=summary,
            risks=analyzer.risks,
            recommendation="Address the identified risks before deploying this code to the fleet."
        )

    except Exception as e:
        return tool_error(f"Static analysis failed: {str(e)}")
|
||||
|
||||
def _handle_static_analyze(args, **kwargs):
    # Thin registry adapter: unpack the tool-call arguments and delegate.
    return run_analysis(args.get("path"))
|
||||
|
||||
# Expose the analyzer to the agent as the 'static_analyze' tool (QA toolset).
registry.register(
    name="static_analyze",
    toolset="qa",
    schema=STATIC_ANALYZE_SCHEMA,
    handler=_handle_static_analyze,
    emoji="🛡️"
)
|
||||
|
||||
312
tools/tool_validator.py
Normal file
312
tools/tool_validator.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""
|
||||
Poka-Yoke: Tool Hallucination Detection — #922.
|
||||
|
||||
Validation firewall between LLM tool-call output and actual execution.
|
||||
|
||||
Detects and blocks:
|
||||
1. Unknown tool names (hallucinated tools)
|
||||
2. Malformed parameters (wrong types)
|
||||
3. Missing required arguments
|
||||
4. Extra unknown parameters
|
||||
|
||||
Poka-Yoke Type: Detection (catches errors at the boundary before harm)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ValidationSeverity(Enum):
    """Severity of a validation failure; only BLOCK prevents execution."""
    BLOCK = "block"  # Must block execution
    WARN = "warn"    # Warning, may proceed
    INFO = "info"    # Informational
|
||||
|
||||
@dataclass
class ValidationIssue:
    """A single problem found while validating one tool call."""
    severity: ValidationSeverity     # BLOCK / WARN / INFO
    code: str                        # machine-readable code, e.g. "UNKNOWN_TOOL"
    message: str                     # human-readable explanation
    tool_name: str                   # tool the issue applies to
    parameter: Optional[str] = None  # offending parameter, when applicable
    expected: Optional[str] = None   # expected JSON Schema type (WRONG_TYPE only)
    actual: Optional[Any] = None     # actual Python type name (WRONG_TYPE only)
|
||||
|
||||
@dataclass
class ValidationResult:
    """Result of tool call validation."""
    valid: bool        # False when any BLOCK-severity issue was found
    tool_name: str     # tool the result describes
    issues: List[ValidationIssue] = field(default_factory=list)
    # Reserved for auto-corrected arguments; not populated by the detector here.
    corrected_args: Optional[Dict[str, Any]] = None

    @property
    def blocking_issues(self) -> List[ValidationIssue]:
        # Issues severe enough that the call must not execute.
        return [i for i in self.issues if i.severity == ValidationSeverity.BLOCK]

    @property
    def warnings(self) -> List[ValidationIssue]:
        # Non-fatal issues; the call may still proceed.
        return [i for i in self.issues if i.severity == ValidationSeverity.WARN]
|
||||
|
||||
class ToolHallucinationDetector:
    """
    Poka-yoke detector for tool hallucinations.

    Validates tool calls against registered JSON Schemas before execution.
    Rejections are kept in a bounded in-memory log for later analysis.
    """

    def __init__(self, tool_registry: Optional[Dict] = None):
        """
        Initialize detector.

        Args:
            tool_registry: Dict of tool_name -> tool_schema
        """
        self.registry = tool_registry or {}
        self._rejection_log: List[Dict] = []

    def register_tool(self, name: str, schema: Dict):
        """Register a tool with its JSON Schema."""
        self.registry[name] = schema

    def register_tools(self, tools: Dict[str, Dict]):
        """Register multiple tools."""
        self.registry.update(tools)

    def validate_tool_call(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        model: str = "unknown",
    ) -> ValidationResult:
        """
        Validate a tool call against the registry.

        Args:
            tool_name: Name of the tool being called
            arguments: Arguments passed to the tool
            model: Model that generated the call (for logging)

        Returns:
            ValidationResult with validation status
        """
        issues = []

        # 1. Check if tool exists — an unknown name short-circuits everything.
        if tool_name not in self.registry:
            issue = ValidationIssue(
                severity=ValidationSeverity.BLOCK,
                code="UNKNOWN_TOOL",
                message=f"Tool '{tool_name}' does not exist. Available: {', '.join(sorted(self.registry.keys())[:10])}...",
                tool_name=tool_name,
            )
            issues.append(issue)
            self._log_rejection(tool_name, arguments, model, "UNKNOWN_TOOL")
            return ValidationResult(valid=False, tool_name=tool_name, issues=issues)

        schema = self.registry[tool_name]
        params_schema = schema.get("parameters", {}).get("properties", {})
        required = set(schema.get("parameters", {}).get("required", []))

        # 2. Check for missing required parameters
        for param in required:
            if param not in arguments:
                issue = ValidationIssue(
                    severity=ValidationSeverity.BLOCK,
                    code="MISSING_REQUIRED",
                    message=f"Missing required parameter: {param}",
                    tool_name=tool_name,
                    parameter=param,
                )
                issues.append(issue)

        # 3. Check parameter types
        for param_name, param_value in arguments.items():
            if param_name not in params_schema:
                # Unknown parameter — warn only; schemas may be permissive.
                issue = ValidationIssue(
                    severity=ValidationSeverity.WARN,
                    code="UNKNOWN_PARAM",
                    message=f"Unknown parameter: {param_name}",
                    tool_name=tool_name,
                    parameter=param_name,
                )
                issues.append(issue)
                continue

            param_schema = params_schema[param_name]
            expected_type = param_schema.get("type")

            if expected_type and not self._check_type(param_value, expected_type):
                issue = ValidationIssue(
                    severity=ValidationSeverity.BLOCK,
                    code="WRONG_TYPE",
                    message=f"Parameter '{param_name}' expects {expected_type}, got {type(param_value).__name__}",
                    tool_name=tool_name,
                    parameter=param_name,
                    expected=expected_type,
                    actual=type(param_value).__name__,
                )
                issues.append(issue)

        # 4. Check for common hallucination patterns
        hallucination_issues = self._detect_hallucination_patterns(tool_name, arguments)
        issues.extend(hallucination_issues)

        # Determine validity — only BLOCK-severity issues invalidate the call.
        has_blocking = any(i.severity == ValidationSeverity.BLOCK for i in issues)

        if has_blocking:
            self._log_rejection(tool_name, arguments, model,
                                "; ".join(i.code for i in issues if i.severity == ValidationSeverity.BLOCK))

        return ValidationResult(
            valid=not has_blocking,
            tool_name=tool_name,
            issues=issues,
        )

    def _check_type(self, value: Any, expected_type: str) -> bool:
        """Check if value matches expected JSON Schema type.

        BUG FIX: Python's bool is a subclass of int, so True/False used to
        pass as "integer"/"number". JSON Schema treats booleans as a
        distinct type, so they are rejected for numeric slots explicitly.
        """
        if isinstance(value, bool) and expected_type in ("integer", "number"):
            return False
        type_map = {
            "string": str,
            "number": (int, float),
            "integer": int,
            "boolean": bool,
            "array": list,
            "object": dict,
        }

        expected = type_map.get(expected_type)
        if expected is None:
            return True  # Unknown type, assume OK

        return isinstance(value, expected)

    def _detect_hallucination_patterns(self, tool_name: str, arguments: Dict) -> List[ValidationIssue]:
        """Detect common hallucination patterns (all WARN severity)."""
        issues = []

        # Pattern 1: Placeholder values the model forgot to fill in.
        placeholder_patterns = [
            r"^<.*>$",  # <placeholder>
            r"^\[.*\]$",  # [placeholder]
            r"^TODO$|^FIXME$",  # TODO/FIXME
            r"^example\.com$",  # example.com
            r"^127\.0\.0\.1$",  # localhost
        ]

        for param_name, param_value in arguments.items():
            if isinstance(param_value, str):
                for pattern in placeholder_patterns:
                    if re.match(pattern, param_value, re.IGNORECASE):
                        issues.append(ValidationIssue(
                            severity=ValidationSeverity.WARN,
                            code="PLACEHOLDER_VALUE",
                            message=f"Parameter '{param_name}' contains placeholder: {param_value}",
                            tool_name=tool_name,
                            parameter=param_name,
                        ))

        # Pattern 2: Suspiciously long strings (might be hallucinated content)
        for param_name, param_value in arguments.items():
            if isinstance(param_value, str) and len(param_value) > 10000:
                issues.append(ValidationIssue(
                    severity=ValidationSeverity.WARN,
                    code="SUSPICIOUS_LENGTH",
                    message=f"Parameter '{param_name}' is unusually long ({len(param_value)} chars)",
                    tool_name=tool_name,
                    parameter=param_name,
                ))

        return issues

    def _log_rejection(self, tool_name: str, arguments: Dict, model: str, reason: str):
        """Record a rejected tool call in the bounded in-memory log."""
        import time

        entry = {
            "timestamp": time.time(),
            "tool_name": tool_name,
            # Truncate argument values so the log stays small.
            "arguments": {k: str(v)[:100] for k, v in arguments.items()},
            "model": model,
            "reason": reason,
        }

        self._rejection_log.append(entry)

        # Keep log bounded
        if len(self._rejection_log) > 1000:
            self._rejection_log = self._rejection_log[-500:]

        logger.warning(
            "Tool hallucination blocked: tool=%s, model=%s, reason=%s",
            tool_name, model, reason
        )

    def get_rejection_stats(self) -> Dict:
        """Get statistics on rejected tool calls, keyed by reason and tool."""
        if not self._rejection_log:
            return {"total": 0, "by_reason": {}, "by_tool": {}}

        by_reason = {}
        by_tool = {}

        for entry in self._rejection_log:
            reason = entry["reason"]
            tool = entry["tool_name"]

            by_reason[reason] = by_reason.get(reason, 0) + 1
            by_tool[tool] = by_tool.get(tool, 0) + 1

        return {
            "total": len(self._rejection_log),
            "by_reason": by_reason,
            "by_tool": by_tool,
        }

    def format_validation_report(self, result: ValidationResult) -> str:
        """Format validation result as human-readable report."""
        if result.valid:
            return f"✅ {result.tool_name}: valid"

        lines = [f"❌ {result.tool_name}: BLOCKED"]
        for issue in result.blocking_issues:
            lines.append(f"  [{issue.code}] {issue.message}")

        for issue in result.warnings:
            lines.append(f"  ⚠️ [{issue.code}] {issue.message}")

        return "\n".join(lines)
|
||||
|
||||
|
||||
def create_rejection_response(result: ValidationResult) -> Dict:
    """
    Build the tool-role message returned for a blocked tool call.

    Surfacing the rejection as a normal tool result lets the agent read the
    issues and self-correct on its next turn.
    """
    bullet_lines = []
    for issue in result.blocking_issues:
        bullet_lines.append(f"- [{issue.code}] {issue.message}")
    issues_text = "\n".join(bullet_lines)

    return {
        "role": "tool",
        "content": f"""Tool call rejected: {result.tool_name}

Issues found:
{issues_text}

Please check the tool name and parameters, then try again with valid arguments.""",
    }
|
||||
Reference in New Issue
Block a user