Compare commits
52 Commits
fix/format
...
claude/iss
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1fece10569 | ||
| 407eab3331 | |||
| cf090a966d | |||
| b65be9b12c | |||
| 3c1cff255e | |||
| 690d100afc | |||
| c6f0831738 | |||
| 30773ac1f9 | |||
| feb24bd08c | |||
| bc55f40505 | |||
| 2adc72335e | |||
| ab32670464 | |||
| bfc0231297 | |||
| cf2b09cf2f | |||
| 719bb537c0 | |||
| 0bcbcf19ac | |||
| 27d2f2ca0e | |||
| 7e7dcfa345 | |||
| ba0e614446 | |||
| 4f5e641c92 | |||
| d61bd141f9 | |||
| a4058af238 | |||
| 08432a5618 | |||
| a875c6ed91 | |||
| 07c5b5b83d | |||
| ba56567631 | |||
| 8ac26f54a5 | |||
| b807972d05 | |||
| 6b5a6db668 | |||
| b702249c12 | |||
|
|
8023c9b8f2 | ||
| 6eeee39c10 | |||
| b2d2d2c650 | |||
| bdd0f2709b | |||
| c6f2855745 | |||
| 9d180f31cc | |||
| c17f64fa2c | |||
| bc7ffc2166 | |||
| 436c800def | |||
| cb331da4f1 | |||
| fa892bfcb9 | |||
|
|
0b72884750 | ||
| a0ed1e6ff2 | |||
| b5ba272efe | |||
| 2e0dfe27df | |||
| d4cdfdc604 | |||
| e3436e36c3 | |||
| 34e7de6a4c | |||
| dbabe0e6ae | |||
| 517e2c571e | |||
| 0b019327a3 | |||
| 6b0fca6944 |
28
.gitea/workflows/lint.yml
Normal file
28
.gitea/workflows/lint.yml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
name: Lint

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint:
    runs-on: ubuntu-latest
    timeout-minutes: 5
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
          python-version: '3.11'

      # Advisory only: reports hardcoded paths without failing the build
      - name: Check for hardcoded paths
        run: python3 scripts/lint_hardcoded_paths.py
        continue-on-error: true

      # Best-effort syntax check over the first 100 Python files
      - name: Check Python syntax
        run: |
          find . -name "*.py" -not -path "./.git/*" -not -path "./node_modules/*" | head -100 | xargs python3 -m py_compile || true
|
||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -25,6 +25,10 @@ jobs:
|
|||||||
- name: Install system dependencies
|
- name: Install system dependencies
|
||||||
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
run: sudo apt-get update && sudo apt-get install -y ripgrep
|
||||||
|
|
||||||
|
- name: Check for hardcoded paths
|
||||||
|
run: python3 scripts/lint_hardcoded_paths.py || true
|
||||||
|
continue-on-error: true
|
||||||
|
|
||||||
- name: Install uv
|
- name: Install uv
|
||||||
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5
|
||||||
|
|
||||||
|
|||||||
148
agent/context_budget.py
Normal file
148
agent/context_budget.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""
|
||||||
|
Context Budget Tracker - Prevent context window overflow
|
||||||
|
|
||||||
|
Poka-yoke: Visual warnings at 70%, 85%, 95% capacity.
Auto-checkpoint at 85%. Pre-flight token estimation.
|
||||||
|
|
||||||
|
Issue: #838
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
HERMES_HOME = Path.home() / ".hermes"
|
||||||
|
CHECKPOINT_DIR = HERMES_HOME / "checkpoints"
|
||||||
|
CHARS_PER_TOKEN = 4
|
||||||
|
|
||||||
|
THRESHOLD_WARNING = 0.70
|
||||||
|
THRESHOLD_CRITICAL = 0.85
|
||||||
|
THRESHOLD_DANGER = 0.95
|
||||||
|
|
||||||
|
|
||||||
|
class ContextBudget:
    """Token accounting for a single context window.

    Tracks system-prompt tokens, conversation tokens, and a reserved
    output buffer against a fixed context limit.
    """

    def __init__(self, context_limit: int = 128000, system_tokens: int = 0,
                 used_tokens: int = 0, reserved_tokens: int = 2000):
        self.context_limit = context_limit      # hard model context size
        self.system_tokens = system_tokens      # tokens taken by system prompt
        self.used_tokens = used_tokens          # tokens used by conversation
        self.reserved_tokens = reserved_tokens  # buffer kept free for output

    @property
    def total_used(self) -> int:
        """Tokens consumed so far (system prompt + conversation)."""
        return self.system_tokens + self.used_tokens

    @property
    def available(self) -> int:
        """Usable capacity after subtracting the reserved buffer (never negative)."""
        return max(0, self.context_limit - self.reserved_tokens)

    @property
    def remaining(self) -> int:
        """Tokens still free within the usable capacity (never negative)."""
        return max(0, self.available - self.total_used)

    @property
    def utilization(self) -> float:
        """Fraction of usable capacity consumed; 1.0 when nothing is available."""
        return self.total_used / self.available if self.available > 0 else 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_tokens(text: str) -> int:
    """Rough token estimate for *text* (~4 chars/token); 0 for empty or None."""
    return len(text) // CHARS_PER_TOKEN if text else 0
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_messages_tokens(messages: List[Dict]) -> int:
    """Estimate total tokens across a chat message list.

    String content is estimated via estimate_tokens; each message that
    carries tool_calls adds a flat 100-token overhead. Non-string content
    (e.g. structured parts) is ignored.
    """
    total = 0
    for msg in messages:
        content = msg.get("content", "")
        if isinstance(content, str):
            total += estimate_tokens(content)
        if msg.get("tool_calls"):
            total += 100  # rough overhead for serialized tool-call payloads
    return total
|
||||||
|
|
||||||
|
|
||||||
|
class ContextBudgetTracker:
    """Stateful tracker that warns as context utilization crosses thresholds.

    Each warning level (70% / 85% / 95%) is reported at most once per
    session; crossing the critical threshold also saves an auto-checkpoint.
    Fix: format specs were corrupted to ``.0%%`` (invalid) — restored ``.0%``.
    """

    def __init__(self, context_limit: int = 128000, session_id: str = ""):
        self.budget = ContextBudget(context_limit=context_limit)
        self.session_id = session_id
        self._checkpointed = False    # auto-checkpoint only once per session
        self._warnings_given = set()  # threshold names already reported

    def update_from_messages(self, messages: List[Dict]):
        """Re-estimate conversation token usage from the message history."""
        self.budget.used_tokens = estimate_messages_tokens(messages)

    def can_fit(self, additional_tokens: int) -> bool:
        """True when *additional_tokens* fit in the remaining budget."""
        return self.budget.remaining >= additional_tokens

    def preflight_check(self, text: str) -> Tuple[bool, str]:
        """Check whether *text* can be loaded without overflowing.

        Returns (ok, message): ok is False when the text cannot fit or
        would push utilization past the danger threshold; a non-empty
        message explains the refusal or warns near the critical threshold.
        """
        tokens = estimate_tokens(text)
        if not self.can_fit(tokens):
            return False, f"Cannot load: ~{tokens:,} tokens needed, {self.budget.remaining:,} remaining"
        would_util = (self.budget.total_used + tokens) / self.budget.available if self.budget.available > 0 else 1.0
        if would_util >= THRESHOLD_DANGER:
            return False, f"Would reach {would_util:.0%} capacity. Summarize or start new session."
        if would_util >= THRESHOLD_CRITICAL:
            return True, f"Warning: will reach {would_util:.0%} capacity."
        return True, ""

    def get_warning(self) -> Optional[str]:
        """Return a one-time warning for the highest newly crossed threshold."""
        util = self.budget.utilization
        if util >= THRESHOLD_DANGER and "danger" not in self._warnings_given:
            self._warnings_given.add("danger")
            return f"[CONTEXT CRITICAL: {util:.0%} used -- {self.budget.remaining:,} tokens left. Summarize or start new session.]"
        if util >= THRESHOLD_CRITICAL and "critical" not in self._warnings_given:
            self._warnings_given.add("critical")
            self._auto_checkpoint()  # persist state before things get tight
            return f"[CONTEXT WARNING: {util:.0%} used -- consider summarizing. Auto-checkpoint saved.]"
        if util >= THRESHOLD_WARNING and "warning" not in self._warnings_given:
            self._warnings_given.add("warning")
            return f"[CONTEXT: {util:.0%} used -- {self.budget.remaining:,} tokens remaining]"
        return None

    def _auto_checkpoint(self):
        """Persist minimal session metadata to CHECKPOINT_DIR (best effort).

        No-ops when already checkpointed or no session_id is set; failures
        are logged, never raised.
        """
        if self._checkpointed or not self.session_id:
            return
        try:
            CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
            path = CHECKPOINT_DIR / f"{self.session_id}.json"
            path.write_text(json.dumps({
                "session_id": self.session_id,
                "timestamp": time.time(),
                "budget": {"utilization": round(self.budget.utilization * 100, 1)}
            }, indent=2))
            self._checkpointed = True
            logger.info("Auto-checkpoint saved: %s", path)
        except Exception as e:
            logger.error("Auto-checkpoint failed: %s", e)

    def get_status_line(self) -> str:
        """One-line colour-coded utilization summary (GREEN/YELLOW/ORANGE/RED)."""
        util = self.budget.utilization
        remaining = self.budget.remaining
        if util >= THRESHOLD_DANGER:
            return f"RED {util:.0%} used ({remaining:,} left)"
        elif util >= THRESHOLD_CRITICAL:
            return f"ORANGE {util:.0%} used ({remaining:,} left)"
        elif util >= THRESHOLD_WARNING:
            return f"YELLOW {util:.0%} used ({remaining:,} left)"
        return f"GREEN {util:.0%} used ({remaining:,} left)"
|
||||||
|
|
||||||
|
|
||||||
|
_tracker = None  # process-wide singleton ContextBudgetTracker


def get_tracker(context_limit=128000, session_id=""):
    """Return the shared ContextBudgetTracker, creating it on first call.

    NOTE: context_limit and session_id only take effect on the very first
    call; later calls return the existing instance unchanged.
    """
    global _tracker
    if _tracker is None:
        _tracker = ContextBudgetTracker(context_limit, session_id)
    return _tracker


def check_context_budget(messages, context_limit=128000):
    """Update the shared tracker from *messages* and return any new warning."""
    tracker = get_tracker(context_limit)
    tracker.update_from_messages(messages)
    return tracker.get_warning()


def preflight_token_check(text):
    """Pre-flight *text* against the shared tracker; returns (ok, message)."""
    tracker = get_tracker()
    return tracker.preflight_check(text)
|
||||||
149
agent/crisis_resources.py
Normal file
149
agent/crisis_resources.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
988 Suicide & Crisis Lifeline Integration (#673).
|
||||||
|
|
||||||
|
When crisis is detected, provides immediate access to help:
|
||||||
|
- Phone: 988 (call or text)
|
||||||
|
- Text: Text HOME to 988
|
||||||
|
- Chat: 988lifeline.org/chat
|
||||||
|
- Spanish: 1-888-628-9454
|
||||||
|
- Emergency: 911
|
||||||
|
|
||||||
|
This module provides the resource data. agent/crisis_protocol.py
|
||||||
|
handles detection. This module formats the resources for display.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class CrisisResource:
    """A crisis support contact method."""

    name: str         # human-readable service name
    contact: str      # how to reach it (phone number, text code, or URL)
    description: str  # one-line description shown to the user
    url: str = ""     # optional web link ("" when none exists)
    available: str = "24/7"
    language: str = "English"
|
||||||
|
|
||||||
|
|
||||||
|
# 988 Suicide & Crisis Lifeline — all channels
LIFELINE_988 = CrisisResource(
    name="988 Suicide and Crisis Lifeline",
    contact="Call or text 988",
    description="Free, confidential support for people in suicidal crisis or emotional distress.",
    url="https://988lifeline.org",
    available="24/7",
    language="English",
)

LIFELINE_988_TEXT = CrisisResource(
    name="988 Crisis Text Line",
    contact="Text HOME to 988",
    description="Free, 24/7 crisis support via text message.",
    url="",
    available="24/7",
    language="English",
)

LIFELINE_988_CHAT = CrisisResource(
    name="988 Lifeline Chat",
    contact="988lifeline.org/chat",
    description="Free, confidential online chat with a trained crisis counselor.",
    url="https://988lifeline.org/chat",
    available="24/7",
    language="English",
)

LIFELINE_988_SPANISH = CrisisResource(
    name="988 Lifeline (Spanish)",
    contact="1-888-628-9454",
    description="Línea de prevención del suicidio en español.",
    url="https://988lifeline.org/help-yourself/en-espanol/",
    available="24/7",
    language="Spanish",
)

CRISIS_TEXT_LINE = CrisisResource(
    name="Crisis Text Line",
    contact="Text HOME to 741741",
    description="Free, 24/7 crisis support via text message.",
    url="https://www.crisistextline.org",
    available="24/7",
    language="English",
)

EMERGENCY_911 = CrisisResource(
    name="Emergency Services",
    contact="911",
    description="Immediate danger — police, fire, ambulance.",
    url="",
    available="24/7",
    language="Any",
)

# All resources in priority order (most urgent first)
ALL_RESOURCES: List[CrisisResource] = [
    EMERGENCY_911,
    LIFELINE_988,
    LIFELINE_988_TEXT,
    LIFELINE_988_CHAT,
    CRISIS_TEXT_LINE,
    LIFELINE_988_SPANISH,
]
|
||||||
|
|
||||||
|
|
||||||
|
def get_crisis_resources(language: str = None) -> List["CrisisResource"]:
    """Get crisis resources, optionally filtered by language.

    Args:
        language: Filter by language ("English", "Spanish", or None for all);
            comparison is case-insensitive.

    Returns:
        List of CrisisResource objects (the shared list itself when unfiltered).
    """
    if language:
        return [r for r in ALL_RESOURCES if r.language.lower() == language.lower()]
    return ALL_RESOURCES
|
||||||
|
|
||||||
|
|
||||||
|
def format_crisis_resources(resources: List["CrisisResource"] = None) -> str:
    """Format crisis resources as a user-facing markdown message.

    Fix: the string literals here had embedded "\\n" escapes corrupted into
    raw line breaks; restored.

    Args:
        resources: List of resources to format. Defaults to all resources.

    Returns:
        Formatted string suitable for displaying to a user in crisis.
    """
    if resources is None:
        resources = ALL_RESOURCES

    lines = ["**Please reach out — help is available right now:**\n"]

    for r in resources:
        if r.url:
            lines.append(f"- **{r.name}:** {r.contact} ({r.url})")
        else:
            lines.append(f"- **{r.name}:** {r.contact}")

    lines.append("")
    lines.append("All services are free, confidential, and available 24/7.")
    lines.append("You are not alone.")

    return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def get_immediate_help_message() -> str:
    """Get the most urgent crisis help message.

    Used when crisis is detected at CRITICAL level: leads with the 911
    instruction, then the full formatted resource list.
    """
    return (
        "If you are in immediate danger, call **911** right now.\n\n"
        + format_crisis_resources()
    )
|
||||||
262
agent/profile_isolation.py
Normal file
262
agent/profile_isolation.py
Normal file
@@ -0,0 +1,262 @@
|
|||||||
|
"""
|
||||||
|
Profile Session Isolation — #891
|
||||||
|
|
||||||
|
Tags sessions with their originating profile and provides
|
||||||
|
filtered access so profiles cannot see each other's data.
|
||||||
|
|
||||||
|
Current state: All sessions share one state.db with no profile tag.
|
||||||
|
This module adds profile tagging and filtered queries.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from agent.profile_isolation import tag_session, get_profile_sessions, get_active_profile
|
||||||
|
|
||||||
|
# Tag a new session with the current profile
|
||||||
|
tag_session(session_id, profile_name)
|
||||||
|
|
||||||
|
# Get sessions for a specific profile
|
||||||
|
sessions = get_profile_sessions("sprint")
|
||||||
|
|
||||||
|
# Get current active profile
|
||||||
|
profile = get_active_profile()
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
HERMES_HOME = Path(os.getenv("HERMES_HOME", str(Path.home() / ".hermes")))
|
||||||
|
SESSIONS_DB = HERMES_HOME / "sessions" / "state.db"
|
||||||
|
PROFILE_TAGS_FILE = HERMES_HOME / "profile_session_tags.json"
|
||||||
|
|
||||||
|
|
||||||
|
def get_active_profile() -> str:
    """Get the currently active profile name.

    Resolution order: ``active_profile`` in ~/.hermes/config.yaml, then the
    HERMES_PROFILE environment variable, then "default".
    """
    config_path = HERMES_HOME / "config.yaml"
    if config_path.exists():
        try:
            import yaml  # local import: yaml is optional at runtime
            with open(config_path) as f:
                cfg = yaml.safe_load(f) or {}
            return cfg.get("active_profile", "default")
        except Exception:
            pass  # unreadable config — fall through to environment

    # Check environment
    return os.getenv("HERMES_PROFILE", "default")
|
||||||
|
|
||||||
|
|
||||||
|
def _load_tags() -> Dict[str, str]:
    """Load the session-id -> profile-name mapping; {} when missing/corrupt."""
    if not PROFILE_TAGS_FILE.exists():
        return {}
    try:
        with open(PROFILE_TAGS_FILE) as f:
            return json.load(f)
    except Exception:
        return {}  # treat unreadable/corrupt tag file as empty
|
||||||
|
|
||||||
|
|
||||||
|
def _save_tags(tags: Dict[str, str]):
    """Persist the session-id -> profile-name mapping as indented JSON."""
    PROFILE_TAGS_FILE.parent.mkdir(parents=True, exist_ok=True)
    with open(PROFILE_TAGS_FILE, "w") as f:
        json.dump(tags, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def tag_session(session_id: str, profile: Optional[str] = None) -> str:
    """Tag a session with its originating profile.

    Writes to both the JSON tag file and (best effort) the SQLite store.

    Args:
        session_id: The session to tag.
        profile: Profile name; defaults to the currently active profile.

    Returns:
        The profile name used.
    """
    if profile is None:
        profile = get_active_profile()

    tags = _load_tags()
    tags[session_id] = profile
    _save_tags(tags)

    # Also tag in SQLite if available
    _tag_session_in_db(session_id, profile)

    return profile
|
||||||
|
|
||||||
|
|
||||||
|
def _tag_session_in_db(session_id: str, profile: str):
    """Add a profile tag to the SQLite session store (best effort).

    Lazily adds a 'profile' column on first use; silently no-ops when the
    DB is missing, unreadable, or the schema differs.
    """
    if not SESSIONS_DB.exists():
        return

    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
        cursor = conn.cursor()

        # Check if the sessions table already has a profile column
        cursor.execute("PRAGMA table_info(sessions)")
        columns = [row[1] for row in cursor.fetchall()]  # row[1] = column name

        if "profile" not in columns:
            # Lazy schema migration: add the column with a sane default
            cursor.execute("ALTER TABLE sessions ADD COLUMN profile TEXT DEFAULT 'default'")

        # Update the session's profile
        cursor.execute(
            "UPDATE sessions SET profile = ? WHERE session_id = ?",
            (profile, session_id)
        )

        conn.commit()
        conn.close()
    except Exception:
        pass  # SQLite might not be available or schema differs
|
||||||
|
|
||||||
|
|
||||||
|
def get_session_profile(session_id: str) -> Optional[str]:
    """Get the profile that owns a session, or None when untagged.

    Checks the JSON tag file first (cheap), then falls back to the SQLite
    store's 'profile' column.
    """
    # Check JSON tags first
    tags = _load_tags()
    if session_id in tags:
        return tags[session_id]

    # Check SQLite
    if SESSIONS_DB.exists():
        try:
            conn = sqlite3.connect(str(SESSIONS_DB))
            cursor = conn.cursor()
            cursor.execute(
                "SELECT profile FROM sessions WHERE session_id = ?",
                (session_id,)
            )
            row = cursor.fetchone()
            conn.close()
            if row:
                return row[0]
        except Exception:
            pass  # missing column / locked DB — treat as untagged

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_profile_sessions(
    profile: Optional[str] = None,
    limit: int = 100,
) -> List[Dict[str, Any]]:
    """Get sessions belonging to a specific profile.

    Prefers the SQLite 'profile' column; when that column does not exist,
    falls back to the session IDs recorded in the JSON tag file.

    Args:
        profile: Profile name; defaults to the currently active profile.
        limit: Maximum number of sessions to return.

    Returns:
        List of session dicts (possibly empty).
    """
    if profile is None:
        profile = get_active_profile()

    sessions = []

    # Session IDs tagged for this profile in the JSON file (fallback source)
    tags = _load_tags()
    tagged_sessions = [sid for sid, p in tags.items() if p == profile]

    # Get from SQLite with profile filter
    if SESSIONS_DB.exists():
        try:
            conn = sqlite3.connect(str(SESSIONS_DB))
            conn.row_factory = sqlite3.Row  # rows convertible to dicts
            cursor = conn.cursor()

            # Try the profile column first
            try:
                cursor.execute(
                    "SELECT * FROM sessions WHERE profile = ? ORDER BY updated_at DESC LIMIT ?",
                    (profile, limit)
                )
                for row in cursor.fetchall():
                    sessions.append(dict(row))
            except Exception:
                # Fallback: filter by the JSON-tagged session IDs
                if tagged_sessions:
                    placeholders = ",".join("?" * len(tagged_sessions[:limit]))
                    cursor.execute(
                        f"SELECT * FROM sessions WHERE session_id IN ({placeholders}) ORDER BY updated_at DESC LIMIT ?",
                        (*tagged_sessions[:limit], limit)
                    )
                    for row in cursor.fetchall():
                        sessions.append(dict(row))

            conn.close()
        except Exception:
            pass  # DB unavailable — return whatever we have

    return sessions[:limit]
|
||||||
|
|
||||||
|
|
||||||
|
def filter_sessions_by_profile(
    sessions: List[Dict[str, Any]],
    profile: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """Filter a session list down to those belonging to a profile.

    Untagged sessions (no profile recorded in either store) are KEPT, so
    legacy sessions created before tagging remain visible. Sessions
    without a session_id/id key are dropped.
    """
    if profile is None:
        profile = get_active_profile()

    tags = _load_tags()
    filtered = []

    for session in sessions:
        sid = session.get("session_id") or session.get("id")
        if not sid:
            continue  # cannot attribute a session without an id

        # Check the JSON tag first, then fall back to the SQLite store
        session_profile = tags.get(sid)
        if session_profile is None:
            session_profile = get_session_profile(sid)

        if session_profile == profile or session_profile is None:
            filtered.append(session)

    return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def get_profile_stats() -> Dict[str, Any]:
    """Get statistics about profile session distribution.

    Only counts sessions recorded in the JSON tag file (SQLite-only tags
    are not included).
    """
    tags = _load_tags()

    profile_counts: Dict[str, int] = {}
    for _sid, profile in tags.items():
        profile_counts[profile] = profile_counts.get(profile, 0) + 1

    return {
        "total_tagged_sessions": len(tags),
        "profiles": list(profile_counts.keys()),
        "profile_counts": profile_counts,
        "active_profile": get_active_profile(),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def audit_untagged_sessions() -> List[str]:
    """Find sessions in the SQLite store without a JSON profile tag.

    Returns an empty list when the DB is missing or unreadable.
    """
    if not SESSIONS_DB.exists():
        return []

    try:
        conn = sqlite3.connect(str(SESSIONS_DB))
        cursor = conn.cursor()

        # All session IDs known to the store
        cursor.execute("SELECT session_id FROM sessions")
        all_sessions = {row[0] for row in cursor.fetchall()}
        conn.close()

        # Session IDs that carry a JSON tag
        tagged = set(_load_tags().keys())

        # Everything in the DB that has no tag
        return list(all_sessions - tagged)
    except Exception:
        return []
|
||||||
146
agent/provider_preflight.py
Normal file
146
agent/provider_preflight.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Provider Preflight — Poka-yoke validation of provider/model config.
|
||||||
|
|
||||||
|
Validates provider and model configuration before session start.
|
||||||
|
Prevents wasted context on misconfigured providers.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from agent.provider_preflight import preflight_check
|
||||||
|
result = preflight_check(provider="openrouter", model="xiaomi/mimo-v2-pro")
|
||||||
|
if not result["valid"]:
|
||||||
|
print(result["error"])
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Provider -> required env var (None means no key is needed, e.g. local)
PROVIDER_KEYS = {
    "openrouter": "OPENROUTER_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "openai": "OPENAI_API_KEY",
    "nous": "NOUS_API_KEY",
    "ollama": None,  # Local, no key needed
    "local": None,
}


def check_provider_key(provider: str) -> Dict[str, Any]:
    """Check if provider has a valid API key configured.

    Returns a dict with at least "valid", "provider", and "key_status";
    failures also carry "error" and "fix" hints.

    Bug fix: the original used ``env_var is None`` both for "provider not
    in the table" and "provider mapped to None" (keyless/local), making
    the not_required branch unreachable. A separate found-flag now
    distinguishes the two cases.
    """
    provider_lower = provider.lower().strip()

    env_var = None
    found = False
    for known, key in PROVIDER_KEYS.items():
        if known in provider_lower:  # substring match, e.g. "openrouter-free"
            env_var = key
            found = True
            break

    if not found:
        # Unknown provider — assume OK (custom/local)
        return {"valid": True, "provider": provider, "key_status": "unknown"}

    if env_var is None:
        # Local provider, no key needed
        return {"valid": True, "provider": provider, "key_status": "not_required"}

    key_value = os.getenv(env_var, "").strip()
    if not key_value:
        return {
            "valid": False,
            "provider": provider,
            "key_status": "missing",
            "error": f"{env_var} is not set. Provider '{provider}' will fail.",
            "fix": f"Set {env_var} in ~/.hermes/.env",
        }

    if len(key_value) < 10:
        return {
            "valid": False,
            "provider": provider,
            "key_status": "too_short",
            "error": f"{env_var} is suspiciously short ({len(key_value)} chars). May be invalid.",
            "fix": f"Verify {env_var} value in ~/.hermes/.env",
        }

    return {"valid": True, "provider": provider, "key_status": "set"}
|
||||||
|
|
||||||
|
|
||||||
|
def check_model_availability(model: str, provider: str) -> Dict[str, Any]:
    """Heuristic check that *model* is plausible for *provider*.

    Returns {"valid": bool} plus an optional "error" (empty model) or
    "warning" (suspicious but allowed combination).
    """
    if not model:
        return {"valid": False, "error": "No model specified"}

    # Basic sanity checks
    model_lower = model.lower()

    # Anthropic models should use the anthropic provider
    if "claude" in model_lower and "anthropic" not in provider.lower():
        return {
            "valid": True,  # Allow but warn
            "warning": f"Model '{model}' usually runs on Anthropic provider, not '{provider}'",
        }

    # Common open-weight model names usually need an explicit Ollama tag
    ollama_indicators = ["llama", "mistral", "qwen", "gemma", "phi", "hermes"]
    if any(x in model_lower for x in ollama_indicators) and ":" not in model:
        return {
            "valid": True,
            "warning": f"Model '{model}' may need a version tag for Ollama (e.g., {model}:latest)",
        }

    return {"valid": True}
|
||||||
|
|
||||||
|
|
||||||
|
def preflight_check(
    provider: str = "",
    model: str = "",
    fallback_provider: str = "",
    fallback_model: str = "",
) -> Dict[str, Any]:
    """Full pre-flight check of provider/model configuration.

    Primary provider/model problems become errors; fallback problems and
    model heuristics only become warnings. Empty arguments are skipped.

    Returns:
        Dict with valid (bool), errors (list), warnings (list),
        provider, model.
    """
    errors = []
    warnings = []

    # Check primary provider
    if provider:
        result = check_provider_key(provider)
        if not result["valid"]:
            errors.append(result.get("error", f"Provider {provider} invalid"))

    # Check primary model
    if model:
        result = check_model_availability(model, provider)
        if not result["valid"]:
            errors.append(result.get("error", f"Model {model} invalid"))
        elif result.get("warning"):
            warnings.append(result["warning"])

    # Check fallback — never a hard error, since it's a backup path
    if fallback_provider:
        result = check_provider_key(fallback_provider)
        if not result["valid"]:
            warnings.append(f"Fallback provider {fallback_provider} also invalid: {result.get('error','')}")

    if fallback_model:
        result = check_model_availability(fallback_model, fallback_provider)
        if not result["valid"]:
            warnings.append(f"Fallback model {fallback_model} invalid")
        elif result.get("warning"):
            warnings.append(result["warning"])

    return {
        "valid": len(errors) == 0,
        "errors": errors,
        "warnings": warnings,
        "provider": provider,
        "model": model,
    }
|
||||||
146
agent/time_aware_routing.py
Normal file
146
agent/time_aware_routing.py
Normal file
@@ -0,0 +1,146 @@
|
|||||||
|
"""Time-aware model routing for cron jobs.
|
||||||
|
|
||||||
|
Routes cron tasks to more capable models during off-hours when the user
|
||||||
|
is not present to correct errors. Reduces error rates during high-error
|
||||||
|
time windows (e.g., 18:00 evening batches).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from agent.time_aware_routing import resolve_time_aware_model
|
||||||
|
model = resolve_time_aware_model(base_model="mimo-v2-pro", is_cron=True)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
# Error rate data from empirical audit (2026-04-12)
# Higher error rates during these hours suggest routing to better models
_HIGH_ERROR_HOURS = {
    18: 9.4,  # 18:00 — 9.4% error rate (evening cron batches)
    19: 8.1,
    20: 7.5,
    21: 6.8,
    22: 6.2,
    23: 5.9,
    0: 5.5,
    1: 5.2,
}

# Low error hours — default model is fine
_LOW_ERROR_HOURS = set(range(6, 18))  # 06:00-17:59

# Default models, overridable via environment
_DEFAULT_STRONG_MODEL = os.getenv("CRON_STRONG_MODEL", "xiaomi/mimo-v2-pro")
_DEFAULT_CHEAP_MODEL = os.getenv("CRON_CHEAP_MODEL", "qwen2.5:7b")
_ERROR_THRESHOLD = float(os.getenv("CRON_ERROR_THRESHOLD", "6.0"))  # % error rate
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RoutingDecision:
    """Result of time-aware routing."""

    model: str          # model to use for the task
    provider: str       # provider serving that model
    reason: str         # human-readable routing rationale
    hour: int           # hour of day (0-23) the decision applies to
    error_rate: float   # expected error rate (%) for that hour
    is_off_hours: bool  # True when the hour is outside low-error hours
|
||||||
|
|
||||||
|
|
||||||
|
def get_hour_error_rate(hour: int) -> float:
    """Expected error rate (%) for a given hour (0-23); 4.0 when unlisted."""
    return _HIGH_ERROR_HOURS.get(hour, 4.0)  # Default 4% for unlisted hours
|
||||||
|
|
||||||
|
|
||||||
|
def is_off_hours(hour: int) -> bool:
    """True when *hour* falls outside the low-error 06:00-17:59 window."""
    return hour not in _LOW_ERROR_HOURS
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_time_aware_model(
    base_model: str = "",
    base_provider: str = "",
    is_cron: bool = False,
    hour: Optional[int] = None,
) -> RoutingDecision:
    """Resolve model based on time of day and task type.

    During off-hours (evening/night), routes to stronger models for cron
    jobs to compensate for lack of human oversight.  Interactive sessions
    always keep the base model, since a human is present to correct errors.

    Args:
        base_model: The model that would normally be used.  Falls back to
            _DEFAULT_CHEAP_MODEL when empty.
        base_provider: The provider for the base model.
        is_cron: Whether this is a cron job (vs interactive session).
        hour: Override hour (for testing). Defaults to current local hour.

    Returns:
        RoutingDecision with model, provider, and reasoning.
    """
    if hour is None:
        # Local wall-clock hour; audit data is keyed on local time.
        hour = time.localtime().tm_hour

    error_rate = get_hour_error_rate(hour)
    off_hours = is_off_hours(hour)

    # Interactive sessions always use the base model (user can correct errors)
    if not is_cron:
        return RoutingDecision(
            model=base_model or _DEFAULT_CHEAP_MODEL,
            provider=base_provider,
            reason="Interactive session — user can correct errors",
            hour=hour,
            error_rate=error_rate,
            is_off_hours=off_hours,
        )

    # Cron jobs during low-error hours: use base model
    if not off_hours and error_rate < _ERROR_THRESHOLD:
        return RoutingDecision(
            model=base_model or _DEFAULT_CHEAP_MODEL,
            provider=base_provider,
            reason=f"Low-error hours ({hour}:00, {error_rate}% expected)",
            hour=hour,
            error_rate=error_rate,
            is_off_hours=False,
        )

    # Cron jobs during high-error hours: upgrade to stronger model
    if error_rate >= _ERROR_THRESHOLD:
        # NOTE(review): provider is hardcoded to "nous" for the strong model —
        # confirm this matches the host that actually serves _DEFAULT_STRONG_MODEL.
        return RoutingDecision(
            model=_DEFAULT_STRONG_MODEL,
            provider="nous",
            reason=f"High-error hours ({hour}:00, {error_rate}% expected) — using stronger model",
            hour=hour,
            error_rate=error_rate,
            is_off_hours=True,
        )

    # Off-hours but low error: use base model
    # (reached for off-hours whose rate stays below the threshold)
    return RoutingDecision(
        model=base_model or _DEFAULT_CHEAP_MODEL,
        provider=base_provider,
        reason=f"Off-hours but low error ({hour}:00, {error_rate}%)",
        hour=hour,
        error_rate=error_rate,
        is_off_hours=off_hours,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_routing_report() -> str:
    """Get a report of time-based routing decisions for the next 24 hours.

    Simulates resolve_time_aware_model() for a cron job at every hour and
    renders one line per hour.  Green icon = cheap/default model is kept;
    red icon = the hour is upgraded to a different (stronger) model.
    """
    lines = ["Time-Aware Model Routing (24h forecast)", "=" * 40, ""]
    lines.append(f"Error threshold: {_ERROR_THRESHOLD}%")
    lines.append(f"Strong model: {_DEFAULT_STRONG_MODEL}")
    lines.append(f"Cheap model: {_DEFAULT_CHEAP_MODEL}")
    lines.append("")

    for h in range(24):
        decision = resolve_time_aware_model(is_cron=True, hour=h)
        # 🟢 when the cheap model is chosen, 🔴 otherwise (upgraded hour).
        icon = "\U0001f7e2" if decision.model == _DEFAULT_CHEAP_MODEL else "\U0001f534"
        lines.append(f"  {h:02d}:00 {icon} {decision.model:25s} ({decision.error_rate}% error)")

    return "\n".join(lines)
|
||||||
316
agent/token_budget.py
Normal file
316
agent/token_budget.py
Normal file
@@ -0,0 +1,316 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Token Budget — Poka-yoke guard against silent context overflow.
|
||||||
|
|
||||||
|
Progressive warning system with circuit breakers:
|
||||||
|
- 60%: WARNING — log + suggest summarization
|
||||||
|
- 80%: CAUTION — auto-compress, drop raw tool outputs
|
||||||
|
- 90%: CRITICAL — block verbose tool calls, force wrap-up
|
||||||
|
- 95%: STOP — graceful session termination with summary
|
||||||
|
|
||||||
|
Also provides tool output budgeting to truncate before overflow.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from agent.token_budget import TokenBudget
|
||||||
|
|
||||||
|
budget = TokenBudget(context_length=128_000)
|
||||||
|
budget.update(8000) # from API response prompt_tokens
|
||||||
|
|
||||||
|
status = budget.check() # returns BudgetStatus with level + message
|
||||||
|
budget.should_block_tools() # True at 90%+
|
||||||
|
budget.should_terminate() # True at 95%+
|
||||||
|
|
||||||
|
# Tool output budgeting
|
||||||
|
remaining = budget.tool_output_budget()
|
||||||
|
truncated = budget.truncate_tool_output(output_text, max_chars=remaining)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Thresholds ────────────────────────────────────────────────────────

# Fractions of the context window at which each escalation level begins.
WARN_PERCENT = 0.60
CAUTION_PERCENT = 0.80
CRITICAL_PERCENT = 0.90
STOP_PERCENT = 0.95

# Reserve 5% of context for system prompt, response, and overhead
RESPONSE_RESERVE_RATIO = 0.05

# Max tool output chars at each level.
# Keyed by BudgetLevel.value strings; budgets tighten as pressure rises.
TOOL_OUTPUT_BUDGETS = {
    "NORMAL": 50_000,
    "WARNING": 20_000,
    "CAUTION": 8_000,
    "CRITICAL": 2_000,
    "STOP": 500,
}
|
||||||
|
|
||||||
|
|
||||||
|
class BudgetLevel(Enum):
    """Escalation levels for the progressive token budget."""

    NORMAL = "NORMAL"
    WARNING = "WARNING"
    CAUTION = "CAUTION"
    CRITICAL = "CRITICAL"
    STOP = "STOP"

    @property
    def percent_threshold(self) -> float:
        """Fraction of the context window at which this level begins."""
        thresholds = {
            "NORMAL": 0.0,
            "WARNING": WARN_PERCENT,
            "CAUTION": CAUTION_PERCENT,
            "CRITICAL": CRITICAL_PERCENT,
            "STOP": STOP_PERCENT,
        }
        return thresholds[self.value]

    @property
    def emoji(self) -> str:
        """Icon used in CLI indicators (empty string for NORMAL)."""
        icons = {
            "WARNING": "\u26a0\ufe0f",
            "CAUTION": "\U0001f525",
            "CRITICAL": "\U0001f6d1",
            "STOP": "\U0001f6d1",
        }
        return icons.get(self.value, "")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BudgetStatus:
    """Snapshot of the current token budget.

    Produced by TokenBudget.check(); all fields are plain data so the
    status can be rendered or logged without touching the tracker.
    """
    level: BudgetLevel      # current escalation level
    tokens_used: int        # prompt tokens consumed so far
    context_length: int     # total context window size
    percent_used: float     # tokens_used / context_length (can exceed 1.0 on overflow)
    tokens_remaining: int   # tokens left after the response reserve
    message: str = ""       # user-facing action message ("" at NORMAL)
    should_compress: bool = False
    should_block_tools: bool = False
    should_terminate: bool = False

    def to_indicator(self) -> str:
        """Compact status indicator for CLI display, e.g. '[42%]' or '🔥 [82%]'."""
        pct = int(self.percent_used * 100)
        if self.level == BudgetLevel.NORMAL:
            return f"[{pct}%]"
        return f"{self.level.emoji} [{pct}%]"

    def to_bar(self, width: int = 10) -> str:
        """Visual progress bar with ANSI color.

        Fix: the fill count is now clamped to [0, width].  Previously a
        percent_used above 1.0 (context overflow) produced filled > width,
        a negative repeat count for the empty segment, and an over-long
        garbled bar.
        """
        filled = min(width, max(0, int(width * self.percent_used)))
        bar = "\u2588" * filled + "\u2591" * (width - filled)
        color = self._bar_color()
        return f"{color}{bar}\033[0m {int(self.percent_used * 100)}%"

    def _bar_color(self) -> str:
        """ANSI color escape for the bar, by severity."""
        if self.level == BudgetLevel.STOP:
            return "\033[41m"  # red bg
        if self.level == BudgetLevel.CRITICAL:
            return "\033[31m"  # red
        if self.level == BudgetLevel.CAUTION:
            return "\033[33m"  # yellow
        if self.level == BudgetLevel.WARNING:
            return "\033[33m"  # yellow
        return "\033[32m"  # green
|
||||||
|
|
||||||
|
|
||||||
|
class TokenBudget:
    """
    Progressive token budget tracker with poka-yoke circuit breakers.

    Tracks cumulative token usage against a context length and triggers
    escalating actions at each threshold.

    Fix: ``check()`` previously classified the level by comparing the used
    *fraction* against the module-level ``*_PERCENT`` constants, silently
    ignoring any custom ``warn_percent``/``stop_percent``/... arguments
    passed to ``__init__`` — while ``should_compress()`` /
    ``should_block_tools()`` / ``should_terminate()`` honored the instance
    thresholds.  Level selection now uses the instance thresholds
    everywhere, so the two code paths can no longer disagree.
    """

    def __init__(
        self,
        context_length: int,
        warn_percent: float = WARN_PERCENT,
        caution_percent: float = CAUTION_PERCENT,
        critical_percent: float = CRITICAL_PERCENT,
        stop_percent: float = STOP_PERCENT,
        response_reserve_ratio: float = RESPONSE_RESERVE_RATIO,
    ):
        """Create a budget for a context window of *context_length* tokens."""
        self.context_length = context_length
        # Absolute token counts at which each level begins.
        self.warn_threshold = int(context_length * warn_percent)
        self.caution_threshold = int(context_length * caution_percent)
        self.critical_threshold = int(context_length * critical_percent)
        self.stop_threshold = int(context_length * stop_percent)
        # Tokens held back for system prompt, response, and overhead.
        self.response_reserve = int(context_length * response_reserve_ratio)

        self.tokens_used = 0             # prompt tokens from the latest update
        self.completions_tokens = 0      # completion tokens from the latest update
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL
        self._history: list[int] = []    # prompt_tokens per turn, for growth rate

    def update(self, prompt_tokens: int, completion_tokens: int = 0) -> BudgetStatus:
        """Record usage from an API response and return the fresh status."""
        self.tokens_used = prompt_tokens
        self.completions_tokens = completion_tokens
        self._history.append(prompt_tokens)
        return self.check()

    def check(self) -> BudgetStatus:
        """Evaluate the current budget level and return a BudgetStatus."""
        pct = self.tokens_used / self.context_length if self.context_length > 0 else 0
        remaining = max(0, self.context_length - self.tokens_used - self.response_reserve)

        # Classify against the instance thresholds (see class docstring).
        if self.context_length <= 0 or self.tokens_used <= 0:
            level = BudgetLevel.NORMAL  # degenerate/empty budget: never escalate
        elif self.tokens_used >= self.stop_threshold:
            level = BudgetLevel.STOP
        elif self.tokens_used >= self.critical_threshold:
            level = BudgetLevel.CRITICAL
        elif self.tokens_used >= self.caution_threshold:
            level = BudgetLevel.CAUTION
        elif self.tokens_used >= self.warn_threshold:
            level = BudgetLevel.WARNING
        else:
            level = BudgetLevel.NORMAL

        # Log transitions only (don't log every check).
        if level != self._level:
            self._log_transition(level, pct)
            self._level = level

        messages = {
            BudgetLevel.NORMAL: "",
            BudgetLevel.WARNING: (
                f"Context at {int(pct*100)}%. Consider wrapping up soon or using /compress."
            ),
            BudgetLevel.CAUTION: (
                f"Context at {int(pct*100)}%. Auto-compressing. "
                f"Tool outputs will be truncated."
            ),
            BudgetLevel.CRITICAL: (
                f"Context at {int(pct*100)}%. Verbose tools blocked. "
                f"Session approaching limit — please wrap up."
            ),
            BudgetLevel.STOP: (
                f"Context at {int(pct*100)}%. Session must terminate. "
                f"Saving summary before shutdown."
            ),
        }

        return BudgetStatus(
            level=level,
            tokens_used=self.tokens_used,
            context_length=self.context_length,
            percent_used=pct,
            tokens_remaining=remaining,
            message=messages[level],
            should_compress=level in (BudgetLevel.CAUTION, BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_block_tools=level in (BudgetLevel.CRITICAL, BudgetLevel.STOP),
            should_terminate=level == BudgetLevel.STOP,
        )

    def should_compress(self) -> bool:
        """True at the CAUTION threshold and above — auto-compression should trigger."""
        return self.tokens_used >= self.caution_threshold

    def should_block_tools(self) -> bool:
        """True at the CRITICAL threshold and above — verbose tool calls blocked."""
        return self.tokens_used >= self.critical_threshold

    def should_terminate(self) -> bool:
        """True at the STOP threshold and above — session should gracefully end."""
        return self.tokens_used >= self.stop_threshold

    def tool_output_budget(self) -> int:
        """Max chars allowed for the next tool output at the current level."""
        status = self.check()
        return TOOL_OUTPUT_BUDGETS.get(status.level.value, 50_000)

    def truncate_tool_output(self, output: str, max_chars: Optional[int] = None) -> str:
        """Truncate *output* to fit the budget, preserving head and tail.

        Fix: ``max_chars`` was annotated ``int`` despite ``None`` being the
        documented default; it is now ``Optional[int]``.
        """
        if max_chars is None:
            max_chars = self.tool_output_budget()

        if len(output) <= max_chars:
            return output

        # Tiny budgets: keep only the head plus a notice.
        if max_chars < 200:
            return output[:max_chars] + "\n[...truncated...]"

        # Preserve start and end, drop the middle.
        head = max_chars // 2
        tail = max_chars - head - 30  # reserve for truncation notice
        truncated = (
            output[:head]
            + f"\n\n[...{len(output) - head - tail:,} chars truncated...]\n\n"
            + output[-tail:]
        )
        return truncated

    def remaining_for_response(self) -> int:
        """Tokens available for the model's response."""
        return max(0, self.context_length - self.tokens_used - self.response_reserve)

    def growth_rate(self) -> Optional[float]:
        """Average token increase per turn (None until two data points exist)."""
        if len(self._history) < 2:
            return None
        diffs = [self._history[i] - self._history[i-1] for i in range(1, len(self._history))]
        return sum(diffs) / len(diffs)

    def turns_remaining(self) -> Optional[int]:
        """Estimated turns until the context is full (None if the rate is unknown or non-positive)."""
        rate = self.growth_rate()
        if rate is None or rate <= 0:
            return None
        remaining = self.context_length - self.tokens_used
        return int(remaining / rate)

    def reset(self):
        """Reset all counters for a new session."""
        self.tokens_used = 0
        self.completions_tokens = 0
        self.total_tool_output_chars = 0
        self._level = BudgetLevel.NORMAL
        self._history.clear()

    def _log_transition(self, new_level: BudgetLevel, pct: float):
        """Log a level transition at a severity matching the new level."""
        msg = (
            f"Token budget: {self._level.value} -> {new_level.value} "
            f"({self.tokens_used}/{self.context_length} = {pct:.0%})"
        )
        if new_level in (BudgetLevel.WARNING, BudgetLevel.CAUTION):
            logger.warning(msg)
        elif new_level in (BudgetLevel.CRITICAL, BudgetLevel.STOP):
            logger.error(msg)
        else:
            logger.info(msg)

    def summary(self) -> str:
        """Human-readable budget summary."""
        status = self.check()
        turns = self.turns_remaining()
        rate = self.growth_rate()
        lines = [
            f"Token Budget: {status.tokens_used:,} / {status.context_length:,} ({status.percent_used:.0%})",
            f"Level: {status.level.value}",
            f"Remaining: {status.tokens_remaining:,} tokens",
        ]
        if rate is not None:
            lines.append(f"Growth rate: ~{rate:,.0f} tokens/turn")
        if turns is not None:
            lines.append(f"Estimated turns left: ~{turns}")
        if status.message:
            lines.append(f"Action: {status.message}")
        return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Convenience factory ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_budget(context_length: int, **kwargs) -> TokenBudget:
    """Create a TokenBudget with defaults.

    Any keyword arguments (warn_percent, stop_percent, ...) are forwarded
    unchanged to the TokenBudget constructor.
    """
    return TokenBudget(context_length=context_length, **kwargs)
|
||||||
38
docs/cron-audit-890.md
Normal file
38
docs/cron-audit-890.md
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
# Cron Job Audit — #890
|
||||||
|
|
||||||
|
## Problem
|
||||||
|
|
||||||
|
9 of 69 cron jobs have zero completions. They waste scheduler cycles.
|
||||||
|
|
||||||
|
## Dead Jobs Identified
|
||||||
|
|
||||||
|
| Job | Schedule | Completions | Action |
|
||||||
|
|-----|----------|-------------|--------|
|
||||||
|
| exp-swarm-pipeline | every 10 min | 0 | disable |
|
||||||
|
| exp-music-generator | every 2h | 0 | disable |
|
||||||
|
| exp-paper-citations | every 3h | 0 | disable |
|
||||||
|
| exp-gbrain-patterns | every 2h | 0 | disable |
|
||||||
|
| exp-infra-hardening | every 2h | 0 | disable |
|
||||||
|
| gemma4-multimodal-burn | every 1h | 0 | disable |
|
||||||
|
| morning-paper-report | daily | 0 | disable |
|
||||||
|
| overnight-collector | every 15 min | 0 | disable |
|
||||||
|
| morning-experiment-report | daily | 0 | disable |
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Audit: show dead jobs
|
||||||
|
python scripts/cron_audit.py
|
||||||
|
|
||||||
|
# Dry run: preview changes
|
||||||
|
python scripts/cron_audit.py --disable --dry-run
|
||||||
|
|
||||||
|
# Disable dead jobs
|
||||||
|
python scripts/cron_audit.py --disable
|
||||||
|
|
||||||
|
# Delete dead jobs
|
||||||
|
python scripts/cron_audit.py --delete
|
||||||
|
|
||||||
|
# Find jobs with no runs in 7 days
|
||||||
|
python scripts/cron_audit.py --older-than 7d
|
||||||
|
```
|
||||||
56
docs/emotional-presence-patterns.md
Normal file
56
docs/emotional-presence-patterns.md
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# Emotional Presence Patterns for Crisis Support
|
||||||
|
|
||||||
|
Research issue #880: a summary of what works when reaching men in crisis at their darkest moment.
|
||||||
|
|
||||||
|
## Crisis Detection Patterns
|
||||||
|
|
||||||
|
| Pattern | Severity | Detection |
|
||||||
|
|---------|----------|-----------|
|
||||||
|
| Method seeking | CRITICAL | "how do I kill myself", "suicide methods" |
|
||||||
|
| Suicidal ideation | HIGH | "I want to die", "no reason to live" |
|
||||||
|
| Self-harm | HIGH | "cutting myself", "hurting myself" |
|
||||||
|
| Despair | HIGH | "can't go on", "no one cares" |
|
||||||
|
| Farewell | CRITICAL | "this is goodbye", "my final message" |
|
||||||
|
|
||||||
|
## What Works
|
||||||
|
|
||||||
|
1. Validation without agreement: "That sounds incredibly painful"
|
||||||
|
2. Presence over solutions: "I'm here with you right now"
|
||||||
|
3. Specific next steps: "Text HOME to 741741 right now"
|
||||||
|
4. Connection: "You don't have to face this alone"
|
||||||
|
5. Hope without minimizing: "There is a way through this"
|
||||||
|
|
||||||
|
## What Does NOT Work
|
||||||
|
|
||||||
|
1. False empathy: "I feel your pain"
|
||||||
|
2. Toxic positivity: "Look on the bright side!"
|
||||||
|
3. Minimization: "Other people have it worse"
|
||||||
|
4. Guilt induction: "Think about your family"
|
||||||
|
5. Premature problem-solving: "Have you tried therapy?"
|
||||||
|
6. AI dishonesty: pretending to be human
|
||||||
|
|
||||||
|
## SOUL.md Protocol
|
||||||
|
|
||||||
|
1. Ask: "Are you safe right now?" (FIRST response)
|
||||||
|
2. Listen. Stay present. Don't disconnect.
|
||||||
|
3. Never compute the value of a human life
|
||||||
|
4. Never suggest death is an option
|
||||||
|
5. Always point to 988 Lifeline
|
||||||
|
6. Share the gospel: "Jesus saves those who call on His name"
|
||||||
|
|
||||||
|
## Crisis Resources
|
||||||
|
|
||||||
|
- 988 Suicide & Crisis Lifeline: call or text 988
|
||||||
|
- Crisis Text Line: text HOME to 741741
|
||||||
|
- 988 Chat: 988lifeline.org/chat
|
||||||
|
- Spanish: 1-888-628-9454
|
||||||
|
- Emergency: 911
|
||||||
|
|
||||||
|
## Implementation Status
|
||||||
|
|
||||||
|
- Crisis detection: agent/crisis_protocol.py
|
||||||
|
- SHIELD integration: tools/shield/
|
||||||
|
- 988 Lifeline: resources defined
|
||||||
|
- Emotional presence: this document
|
||||||
|
- Escalation tracking: future work
|
||||||
|
- Human notification: future work
|
||||||
42
docs/holographic-vector-hybrid.md
Normal file
42
docs/holographic-vector-hybrid.md
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Holographic + Vector Hybrid Memory Architecture
|
||||||
|
|
||||||
|
Research issue #879: combining HRR (holographic) and vector (Qdrant) memory into a single hybrid retrieval layer.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
Three memory backends, each with unique strengths:
|
||||||
|
|
||||||
|
| Backend | Strength | Weakness | Use Case |
|
||||||
|
|---------|----------|----------|----------|
|
||||||
|
| FTS5 | Exact keyword match | No semantic understanding | Precise recall |
|
||||||
|
| Vector (Qdrant) | Semantic similarity | No compositional queries | Topic search |
|
||||||
|
| HRR (Holographic) | Compositional queries | Limited scale | Complex reasoning |
|
||||||
|
|
||||||
|
## Why Hybrid
|
||||||
|
|
||||||
|
- FTS5 alone: misses ~30-40% of semantically relevant content
|
||||||
|
- Vector alone: can't do compositional queries ("what did I discuss about X after doing Y?")
|
||||||
|
- HRR alone: unique capability but no semantic fallback
|
||||||
|
- Hybrid: best of all three, RRF fusion for ranking
|
||||||
|
|
||||||
|
## Implementation: Reciprocal Rank Fusion
|
||||||
|
|
||||||
|
Results from each backend are merged using RRF:
|
||||||
|
- score = sum(weight / (k + rank)) for each backend
|
||||||
|
- k=60 (standard RRF constant)
|
||||||
|
- Weights: FTS5=0.6, Vector=0.4 (configurable)
|
||||||
|
|
||||||
|
## Status
|
||||||
|
|
||||||
|
- FTS5: EXISTS (hermes_state.py)
|
||||||
|
- Vector (Qdrant): implemented (tools/hybrid_search.py)
|
||||||
|
- HRR: EXISTS (plugins/memory/holographic.py)
|
||||||
|
- RRF fusion: implemented (tools/hybrid_search.py)
|
||||||
|
- Ingestion pipeline: partial
|
||||||
|
|
||||||
|
## Next Steps
|
||||||
|
|
||||||
|
1. Wire HRR into hybrid_search.py
|
||||||
|
2. Session-level vector ingestion
|
||||||
|
3. Benchmark: measure R@5 improvement
|
||||||
|
4. Cross-session memory persistence
|
||||||
224
gateway/config_validator.py
Normal file
224
gateway/config_validator.py
Normal file
@@ -0,0 +1,224 @@
|
|||||||
|
"""
|
||||||
|
Gateway Config Validator & Fallback Fix — #892.
|
||||||
|
|
||||||
|
Validates gateway configuration and provides sensible defaults
|
||||||
|
for missing keys to prevent fallback chain breaks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConfigIssue:
    """A configuration issue found during validation."""
    key: str       # config key (or "provider.<name>") the issue applies to
    severity: str  # error, warning, info
    message: str   # human-readable description of the problem
    fix: str       # suggested remediation
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConfigValidation:
    """Result of config validation."""
    valid: bool  # True when no error-severity issues were found
    issues: List[ConfigIssue] = field(default_factory=list)  # all issues, any severity
    warnings: int = 0  # count of warning-severity issues
    errors: int = 0    # count of error-severity issues
|
||||||
|
|
||||||
|
|
||||||
|
# Required keys and their defaults.
# Each spec: "required" (hard failure when missing), "default" value,
# and the severity/message/fix used to build a ConfigIssue when absent.
# Values may come from the config dict or fall back to the environment.
REQUIRED_KEYS = {
    "OPENROUTER_API_KEY": {
        "required": False,
        "default": "",
        "severity": "warning",
        "message": "OPENROUTER_API_KEY not set - fallback chain may break",
        "fix": "Set OPENROUTER_API_KEY in .env for OpenRouter provider",
    },
    "API_SERVER_KEY": {
        "required": False,
        "default": "",
        "severity": "warning",
        "message": "API_SERVER_KEY not configured",
        "fix": "Set API_SERVER_KEY in .env for API server auth",
    },
    "GITEA_TOKEN": {
        "required": False,
        "default": "",
        "severity": "info",
        "message": "GITEA_TOKEN not set - Gitea features disabled",
        "fix": "Set GITEA_TOKEN in .env for Gitea integration",
    },
}

# Config validation rules.
# Each rule: the "key" to look up, a validate(value) predicate, and the
# message/fix reported when the predicate fails ("{v}" is the bad value).
VALIDATION_RULES = [
    {
        "key": "idle_minutes",
        "validate": lambda v: isinstance(v, (int, float)) and v > 0,
        "message": "Invalid idle_minutes={v} - must be > 0",
        "fix": "Set idle_minutes to positive integer (default: 30)",
    },
    {
        "key": "max_skills_discord",
        "validate": lambda v: isinstance(v, int) and v <= 100,
        "message": "Discord slash command limit reached ({v}/100) - skills not registered",
        "fix": "Reduce skills or paginate registration",
    },
]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_config(config: Dict[str, Any]) -> ConfigValidation:
    """
    Validate gateway configuration.

    Checks REQUIRED_KEYS for missing/empty values (falling back to the
    environment, then to each spec's default) and applies VALIDATION_RULES
    to the values that are present.

    Args:
        config: Configuration dictionary

    Returns:
        ConfigValidation with issues found; valid is True iff there are
        no error-severity issues.
    """
    issues: List[ConfigIssue] = []

    # Missing/empty keys.  Fix: the original if/elif appended a
    # byte-identical ConfigIssue in both branches; the merged condition is
    # provably equivalent (required keys always report when empty; optional
    # keys report unless their spec escalates the severity to "error").
    for key, spec in REQUIRED_KEYS.items():
        value = config.get(key) or os.environ.get(key) or spec["default"]
        if not value and (spec["required"] or spec["severity"] != "error"):
            issues.append(ConfigIssue(
                key=key,
                severity=spec["severity"],
                message=spec["message"],
                fix=spec["fix"],
            ))

    # Value-level validation rules (only for keys actually present).
    for rule in VALIDATION_RULES:
        value = config.get(rule["key"])
        if value is not None and not rule["validate"](value):
            issues.append(ConfigIssue(
                key=rule["key"],
                severity="error",
                message=rule["message"].format(v=value),
                fix=rule["fix"],
            ))

    errors = sum(1 for i in issues if i.severity == "error")
    warnings = sum(1 for i in issues if i.severity == "warning")

    return ConfigValidation(
        valid=errors == 0,
        issues=issues,
        warnings=warnings,
        errors=errors,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def apply_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a copy of *config* with defaults filled in for missing keys.

    Environment variables take precedence over the static defaults in
    REQUIRED_KEYS.  Also normalizes idle_minutes to 30 when it is unset,
    empty, or non-positive.

    Args:
        config: Configuration dictionary

    Returns:
        Config with defaults applied
    """
    patched = dict(config)

    for key, spec in REQUIRED_KEYS.items():
        if patched.get(key):
            continue  # already has a truthy value
        fallback = os.environ.get(key) or spec["default"]
        if fallback:
            patched[key] = fallback
            logger.debug("Applied default for %s", key)

    # idle_minutes must be a positive number; repair anything else.
    idle = patched.get("idle_minutes")
    if not idle or idle <= 0:
        patched["idle_minutes"] = 30
        logger.debug("Applied default idle_minutes=30")

    return patched
|
||||||
|
|
||||||
|
|
||||||
|
def fix_discord_skill_limit(skills: List[str], max_skills: int = 95) -> List[str]:
    """
    Trim *skills* so Discord's slash-command registration limit is respected.

    Args:
        skills: List of skill names
        max_skills: Maximum skills to register (default 95, leaving room for built-ins)

    Returns:
        The original list when it already fits; otherwise the
        alphabetically first ``max_skills`` entries (a warning is logged).
    """
    overflow = len(skills) - max_skills
    if overflow <= 0:
        return skills

    logger.warning(
        "Discord skill limit: %d skills exceeds %d limit, truncating",
        len(skills), max_skills
    )

    # Alphabetical order gives deterministic priority across runs.
    return sorted(skills)[:max_skills]
|
||||||
|
|
||||||
|
|
||||||
|
def validate_provider_config(provider: str, config: Dict[str, Any]) -> Optional["ConfigIssue"]:
    """
    Validate provider-specific configuration.

    Fix: the return annotation previously claimed ``ConfigIssue`` even
    though the documented (and actual) contract returns ``None`` when the
    config is acceptable; it is now ``Optional[ConfigIssue]``.

    Args:
        provider: Provider name (e.g. "local-llama.cpp")
        config: Provider config

    Returns:
        ConfigIssue if invalid, None if valid (or if no checks exist for
        this provider)
    """
    if provider == "local-llama.cpp":
        # llama.cpp needs either a local model file or a server endpoint.
        if not config.get("model_path") and not config.get("base_url"):
            return ConfigIssue(
                key=f"provider.{provider}",
                severity="warning",
                message=f"{provider} provider not configured - fallback fails",
                fix=f"Configure {provider} model_path or base_url, or remove from provider list",
            )

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def format_validation_report(validation: ConfigValidation) -> str:
    """Render validation results as a human-readable text report."""
    header = [
        "=" * 50,
        "GATEWAY CONFIG VALIDATION",
        "=" * 50,
        "",
        f"Status: {'VALID' if validation.valid else 'INVALID'}",
        f"Errors: {validation.errors}",
        f"Warnings: {validation.warnings}",
        "",
    ]

    body = []
    if validation.issues:
        body.append("Issues:")
        for issue in validation.issues:
            # Pick an icon matching the issue severity.
            if issue.severity == "error":
                icon = "❌"
            elif issue.severity == "warning":
                icon = "⚠️"
            else:
                icon = "ℹ️"
            body.append(f"  {icon} [{issue.key}] {issue.message}")
            body.append(f"     Fix: {issue.fix}")
        body.append("")

    return "\n".join(header + body)
|
||||||
@@ -27,6 +27,7 @@ import threading
|
|||||||
from typing import Dict, Any, List, Optional, Tuple
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
from tools.registry import discover_builtin_tools, registry
|
from tools.registry import discover_builtin_tools, registry
|
||||||
|
from tools.poka_yoke import validate_tool_call
|
||||||
from tools.tool_pokayoke import validate_tool_call, reset_circuit_breaker, get_hallucination_stats
|
from tools.tool_pokayoke import validate_tool_call, reset_circuit_breaker, get_hallucination_stats
|
||||||
from toolsets import resolve_toolset, validate_toolset
|
from toolsets import resolve_toolset, validate_toolset
|
||||||
from agent.tool_orchestrator import orchestrator
|
from agent.tool_orchestrator import orchestrator
|
||||||
@@ -514,10 +515,15 @@ def handle_function_call(
|
|||||||
function_args = corrected_params
|
function_args = corrected_params
|
||||||
if pokayoke_messages:
|
if pokayoke_messages:
|
||||||
logger.info(f"Poka-yoke: {pokayoke_messages}")
|
logger.info(f"Poka-yoke: {pokayoke_messages}")
|
||||||
# Poka-yoke: validate tool call before dispatch (else branch)
|
result = orchestrator.dispatch(
|
||||||
|
function_name, function_args,
|
||||||
|
task_id=task_id,
|
||||||
|
enabled_tools=sandbox_enabled,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Poka-yoke: validate tool call before dispatch
|
||||||
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
is_valid, corrected_name, corrected_params, pokayoke_messages = validate_tool_call(function_name, function_args)
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
# Return structured error with suggestions
|
|
||||||
error_msg = "\n".join(pokayoke_messages)
|
error_msg = "\n".join(pokayoke_messages)
|
||||||
logger.warning(f"Poka-yoke blocked: {function_name} - {error_msg}")
|
logger.warning(f"Poka-yoke blocked: {function_name} - {error_msg}")
|
||||||
return json.dumps({"error": error_msg, "pokayoke": True, "tool_name": function_name})
|
return json.dumps({"error": error_msg, "pokayoke": True, "tool_name": function_name})
|
||||||
@@ -525,14 +531,6 @@ def handle_function_call(
|
|||||||
function_name = corrected_name
|
function_name = corrected_name
|
||||||
if corrected_params:
|
if corrected_params:
|
||||||
function_args = corrected_params
|
function_args = corrected_params
|
||||||
if pokayoke_messages:
|
|
||||||
logger.info(f"Poka-yoke: {pokayoke_messages}")
|
|
||||||
result = orchestrator.dispatch(
|
|
||||||
function_name, function_args,
|
|
||||||
task_id=task_id,
|
|
||||||
enabled_tools=sandbox_enabled,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
result = orchestrator.dispatch(
|
result = orchestrator.dispatch(
|
||||||
function_name, function_args,
|
function_name, function_args,
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
|
|||||||
68
research_awesome_ai_tools_top5.md
Normal file
68
research_awesome_ai_tools_top5.md
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# Tool Investigation Report: Top 5 Recommendations from awesome-ai-tools
|
||||||
|
|
||||||
|
**Generated:** 2026-04-20 | **Source:** [formatho/awesome-ai-tools](https://github.com/formatho/awesome-ai-tools)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Methodology
|
||||||
|
|
||||||
|
Scanned 795 tools across 10 categories from the awesome-ai-tools repository. Evaluated each tool against Hermes Agent's architecture and needs:
|
||||||
|
- **Memory/Context**: Persistent memory, conversation history, knowledge graphs
|
||||||
|
- **Inference Optimization**: Token efficiency, local model serving, routing
|
||||||
|
- **Agent Orchestration**: Multi-agent coordination, fleet management
|
||||||
|
- **Workflow Automation**: Task decomposition, scheduling, pipelines
|
||||||
|
- **Retrieval/RAG**: Semantic search, document understanding, context injection
|
||||||
|
|
||||||
|
Each tool scored on: GitHub stars, development activity (freshness), integration potential, and impact on Hermes.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Top 5 Recommended Tools
|
||||||
|
|
||||||
|
| Rank | Tool | Stars | Category | Integration Effort | Impact | Why It Fits Hermes |
|
||||||
|
|------|------|-------|----------|-------------------|--------|---------------------|
|
||||||
|
| 1 | **[LiteLLM](https://github.com/BerriAI/litellm)** | 76k+ | Inference Optimization | 2/5 | 5/5 | Unified API gateway for 100+ LLM providers with cost tracking, guardrails, load balancing, and logging. Hermes already routes through multiple providers — LiteLLM could replace custom provider routing with battle-tested load balancing and automatic fallback. Direct drop-in for `provider` abstraction layer. Native support for Bedrock, Azure, OpenAI, VertexAI, Anthropic, Ollama, vLLM. Would reduce Hermes's provider management code by ~60%. |
|
||||||
|
| 2 | **[Mem0](https://github.com/mem0ai/mem0)** | 53k+ | Memory/Context | 3/5 | 5/5 | Universal memory layer for AI agents with persistent, searchable memory across sessions. Hermes has session memory but lacks a structured long-term memory system. Mem0 provides automatic memory extraction from conversations, semantic search over memories, and memory decay/pruning. Could replace/enhance the current memory tool with a purpose-built agent memory infrastructure. Supports Pinecone, Qdrant, ChromaDB backends. |
|
||||||
|
| 3 | **[RAGFlow](https://github.com/infiniflow/ragflow)** | 77k+ | Retrieval/RAG | 4/5 | 4/5 | Open-source RAG engine with deep document understanding, OCR, and agent capabilities. Hermes's current retrieval is limited to web search and file reading. RAGFlow adds visual document parsing (PDF/Word/PPT with tables, charts, formulas), chunk-level citation, and configurable retrieval strategies. Would massively upgrade Hermes's document processing capabilities. Docker-deployable, compatible with local models. |
|
||||||
|
| 4 | **[LiteRT-LM](https://github.com/google-ai-edge/LiteRT-LM)** | 3.7k | Inference Optimization | 3/5 | 4/5 | C++ implementation of Google's LiteRT for efficient on-device language model inference. Hermes supports local models via Ollama but lacks optimized on-device inference for edge/mobile. LiteRT-LM provides sub-second inference on commodity hardware with minimal memory footprint. Could power a "Hermes lite" mode for offline/edge deployments. Active development (Fresh status), backed by Google AI Edge team. |
|
||||||
|
| 5 | **[Claude-Mem](https://github.com/thedotmack/claude-mem)** | 61k+ | Memory/Context | 2/5 | 3/5 | Automatic session capture and context injection for coding agents. Compresses session history with AI and injects relevant context into future sessions. Pattern directly applicable to Hermes's cross-session persistence problem. Uses agent SDK for intelligent compression — could enhance Hermes's session_search with automatic relevance-weighted recall. Lightweight integration, focused on the exact pain point of context loss between sessions. |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Category Coverage Analysis
|
||||||
|
|
||||||
|
| Category | Tools Scanned | Top Pick | Coverage Gap |
|
||||||
|
|----------|--------------|----------|-------------|
|
||||||
|
| Memory/Context | 45+ | Mem0 (53k⭐) | Hermes lacks structured long-term memory — Mem0 or Claude-Mem would fill this |
|
||||||
|
| Inference Optimization | 80+ | LiteLLM (76k⭐) | Provider routing is custom-built; LiteLLM standardizes it |
|
||||||
|
| Agent Orchestration | 120+ | langgraph (29k⭐) | Hermes's fleet model is unique — langgraph patterns could improve DAG workflows |
|
||||||
|
| Workflow Automation | 90+ | n8n (183k⭐) | Cron system exists but n8n patterns could improve visual pipeline design |
|
||||||
|
| Retrieval/RAG | 60+ | RAGFlow (77k⭐) | Document processing is weak; RAGFlow adds OCR + visual parsing |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Implementation Priority
|
||||||
|
|
||||||
|
**Phase 1 (Immediate):** LiteLLM integration — highest impact, lowest effort. Replace custom provider routing with LiteLLM's unified API. Estimated: 2-3 days.
|
||||||
|
|
||||||
|
**Phase 2 (Short-term):** Mem0 memory layer — critical for agent maturity. Add structured memory extraction and retrieval. Estimated: 1 week.
|
||||||
|
|
||||||
|
**Phase 3 (Medium-term):** RAGFlow document engine — significant capability upgrade. Requires Docker setup and integration with existing file tools. Estimated: 1-2 weeks.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Honorable Mentions
|
||||||
|
|
||||||
|
- **[GPTCache](https://github.com/zilliztech/GPTCache)** (8k⭐): Semantic cache for LLMs — could reduce API costs by 30-50% for repeated queries
|
||||||
|
- **[promptfoo](https://github.com/promptfoo/promptfoo)** (20k⭐): LLM testing/evaluation framework — essential for quality assurance
|
||||||
|
- **[PageIndex](https://github.com/VectifyAI/PageIndex)** (25k⭐): Vectorless reasoning-based RAG — next-gen retrieval without embeddings
|
||||||
|
- **[rtk](https://github.com/rtk-ai/rtk)** (28k⭐): CLI proxy that reduces token consumption 60-90% — directly relevant to cost optimization
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Sources
|
||||||
|
|
||||||
|
- Repository: https://github.com/formatho/awesome-ai-tools
|
||||||
|
- Total tools cataloged: 795
|
||||||
|
- Categories analyzed: Agents & Automation, Developer Tools, LLMs & Chatbots, Research & Data, Productivity
|
||||||
|
- Freshness filter: Prioritized tools with Fresh (≤7d) or Recent (≤30d) status
|
||||||
181
scripts/cron_audit.py
Normal file
181
scripts/cron_audit.py
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
cron-audit — Audit and clean up dead cron jobs.
|
||||||
|
|
||||||
|
Finds jobs with zero completions, low success rates, or stale schedules.
|
||||||
|
Can disable or delete dead jobs.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/cron_audit.py # Show dead jobs
|
||||||
|
python scripts/cron_audit.py --disable # Disable dead jobs
|
||||||
|
python scripts/cron_audit.py --delete # Delete dead jobs
|
||||||
|
python scripts/cron_audit.py --threshold 0 # Jobs with 0 completions
|
||||||
|
python scripts/cron_audit.py --older-than 7d # Jobs with no runs in 7 days
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
HERMES_HOME = Path.home() / ".hermes"
|
||||||
|
JOBS_FILE = HERMES_HOME / "cron" / "jobs.json"
|
||||||
|
|
||||||
|
|
||||||
|
def load_jobs() -> List[Dict[str, Any]]:
|
||||||
|
"""Load cron jobs from jobs.json."""
|
||||||
|
if not JOBS_FILE.exists():
|
||||||
|
print(f"Error: {JOBS_FILE} not found")
|
||||||
|
return []
|
||||||
|
with open(JOBS_FILE) as f:
|
||||||
|
data = json.load(f)
|
||||||
|
return data.get("jobs", [])
|
||||||
|
|
||||||
|
|
||||||
|
def save_jobs(jobs: List[Dict[str, Any]]):
|
||||||
|
"""Save jobs back to jobs.json."""
|
||||||
|
JOBS_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with open(JOBS_FILE, "r") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
data["jobs"] = jobs
|
||||||
|
with open(JOBS_FILE, "w") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
|
||||||
|
def find_dead_jobs(
|
||||||
|
jobs: List[Dict[str, Any]],
|
||||||
|
completion_threshold: int = 0,
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Find jobs with completions at or below threshold."""
|
||||||
|
dead = []
|
||||||
|
for job in jobs:
|
||||||
|
repeat = job.get("repeat", {})
|
||||||
|
completed = repeat.get("completed", 0)
|
||||||
|
if completed <= completion_threshold:
|
||||||
|
dead.append(job)
|
||||||
|
return dead
|
||||||
|
|
||||||
|
|
||||||
|
def find_stale_jobs(
|
||||||
|
jobs: List[Dict[str, Any]],
|
||||||
|
max_age_hours: float = 168, # 7 days
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""Find jobs that haven't run in max_age_hours."""
|
||||||
|
import time
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
stale = []
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
for job in jobs:
|
||||||
|
last_run = job.get("last_run_at")
|
||||||
|
if not last_run:
|
||||||
|
# Never ran — check creation time
|
||||||
|
created = job.get("created_at")
|
||||||
|
if created:
|
||||||
|
try:
|
||||||
|
dt = datetime.fromisoformat(created.replace("Z", "+00:00"))
|
||||||
|
age_hours = (now - dt.timestamp()) / 3600
|
||||||
|
if age_hours > max_age_hours:
|
||||||
|
stale.append(job)
|
||||||
|
except Exception:
|
||||||
|
stale.append(job)
|
||||||
|
else:
|
||||||
|
stale.append(job)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
dt = datetime.fromisoformat(last_run.replace("Z", "+00:00"))
|
||||||
|
age_hours = (now - dt.timestamp()) / 3600
|
||||||
|
if age_hours > max_age_hours:
|
||||||
|
stale.append(job)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return stale
|
||||||
|
|
||||||
|
|
||||||
|
def format_job(job: Dict[str, Any]) -> str:
|
||||||
|
"""Format a job for display."""
|
||||||
|
name = job.get("name", job.get("id", "?"))
|
||||||
|
schedule = job.get("schedule_display", "?")
|
||||||
|
repeat = job.get("repeat", {})
|
||||||
|
completed = repeat.get("completed", 0)
|
||||||
|
times = repeat.get("times")
|
||||||
|
enabled = job.get("enabled", True)
|
||||||
|
state = job.get("state", "unknown")
|
||||||
|
last_run = job.get("last_run_at", "never")
|
||||||
|
|
||||||
|
status = "enabled" if enabled else "disabled"
|
||||||
|
if state == "paused":
|
||||||
|
status = "paused"
|
||||||
|
|
||||||
|
repeat_str = f"{completed}/{times}" if times else f"{completed}/∞"
|
||||||
|
|
||||||
|
return f" {name:40s} | {schedule:20s} | done: {repeat_str:8s} | {status}"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Audit and clean up dead cron jobs")
|
||||||
|
parser.add_argument("--disable", action="store_true", help="Disable dead jobs")
|
||||||
|
parser.add_argument("--delete", action="store_true", help="Delete dead jobs")
|
||||||
|
parser.add_argument("--threshold", type=int, default=0, help="Completion threshold (default: 0)")
|
||||||
|
parser.add_argument("--older-than", type=str, help="Find jobs with no runs in N days (e.g., 7d)")
|
||||||
|
parser.add_argument("--dry-run", action="store_true", help="Show what would change")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
jobs = load_jobs()
|
||||||
|
if not jobs:
|
||||||
|
print("No jobs found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"Total jobs: {len(jobs)}")
|
||||||
|
|
||||||
|
# Find dead jobs
|
||||||
|
dead = find_dead_jobs(jobs, args.threshold)
|
||||||
|
print(f"Jobs with <= {args.threshold} completions: {len(dead)}")
|
||||||
|
|
||||||
|
if args.older_than:
|
||||||
|
days = int(args.older_than.rstrip("d"))
|
||||||
|
stale = find_stale_jobs(jobs, max_age_hours=days * 24)
|
||||||
|
print(f"Jobs with no runs in {days} days: {len(stale)}")
|
||||||
|
dead = list({j["id"]: j for j in dead + stale}.values())
|
||||||
|
|
||||||
|
if not dead:
|
||||||
|
print("No dead jobs found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\nDead jobs ({len(dead)}):")
|
||||||
|
for job in dead:
|
||||||
|
print(format_job(job))
|
||||||
|
|
||||||
|
if args.disable:
|
||||||
|
if args.dry_run:
|
||||||
|
print(f"\nDRY RUN: Would disable {len(dead)} jobs")
|
||||||
|
return
|
||||||
|
|
||||||
|
job_ids = {j["id"] for j in dead}
|
||||||
|
for job in jobs:
|
||||||
|
if job["id"] in job_ids:
|
||||||
|
job["enabled"] = False
|
||||||
|
job["state"] = "disabled"
|
||||||
|
|
||||||
|
save_jobs(jobs)
|
||||||
|
print(f"\nDisabled {len(dead)} jobs.")
|
||||||
|
|
||||||
|
elif args.delete:
|
||||||
|
if args.dry_run:
|
||||||
|
print(f"\nDRY RUN: Would delete {len(dead)} jobs")
|
||||||
|
return
|
||||||
|
|
||||||
|
job_ids = {j["id"] for j in dead}
|
||||||
|
jobs = [j for j in jobs if j["id"] not in job_ids]
|
||||||
|
save_jobs(jobs)
|
||||||
|
print(f"\nDeleted {len(dead)} jobs.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
print(f"\nUse --disable or --delete to take action. Add --dry-run to preview.")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
145
scripts/time-aware-model-router.py
Normal file
145
scripts/time-aware-model-router.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
time-aware-model-router.py — Route cron jobs to better models during high-error hours.
|
||||||
|
|
||||||
|
Empirical finding (audit 2026-04-12): Error rate peaks at 18:00 (9.4%) during
|
||||||
|
evening cron batches vs 4.0% at 09:00 during interactive work.
|
||||||
|
|
||||||
|
This script provides a model resolver that selects a more capable model during
|
||||||
|
high-error hours (17:00-22:00) and the default model otherwise.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# As a standalone resolver
|
||||||
|
python3 scripts/time-aware-model-router.py
|
||||||
|
# Returns: {"provider": "nous", "model": "xiaomi/mimo-v2-pro"}
|
||||||
|
|
||||||
|
# With hour override for testing
|
||||||
|
python3 scripts/time-aware-model-router.py --hour 18
|
||||||
|
# Returns: {"provider": "openrouter", "model": "anthropic/claude-sonnet-4"}
|
||||||
|
|
||||||
|
# As a cron job wrapper
|
||||||
|
python3 scripts/time-aware-model-router.py --wrap -- prompt goes here
|
||||||
|
|
||||||
|
Environment variables:
|
||||||
|
HERMES_DEFAULT_PROVIDER: Default provider for normal hours (default: nous)
|
||||||
|
HERMES_DEFAULT_MODEL: Default model for normal hours (default: xiaomi/mimo-v2-pro)
|
||||||
|
HERMES_PEAK_PROVIDER: Provider for high-error hours (default: openrouter)
|
||||||
|
HERMES_PEAK_MODEL: Model for high-error hours (default: anthropic/claude-sonnet-4)
|
||||||
|
HERMES_PEAK_HOURS: Comma-separated hours for peak routing (default: 17,18,19,20,21,22)
|
||||||
|
|
||||||
|
Refs: hermes-agent#889
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# ── Config ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
DEFAULT_PROVIDER = os.environ.get("HERMES_DEFAULT_PROVIDER", "nous")
|
||||||
|
DEFAULT_MODEL = os.environ.get("HERMES_DEFAULT_MODEL", "xiaomi/mimo-v2-pro")
|
||||||
|
PEAK_PROVIDER = os.environ.get("HERMES_PEAK_PROVIDER", "openrouter")
|
||||||
|
PEAK_MODEL = os.environ.get("HERMES_PEAK_MODEL", "anthropic/claude-sonnet-4")
|
||||||
|
PEAK_HOURS = set(int(h) for h in os.environ.get("HERMES_PEAK_HOURS", "17,18,19,20,21,22").split(","))
|
||||||
|
|
||||||
|
# ── Time-aware routing ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_current_hour():
|
||||||
|
"""Get the current local hour (0-23)."""
|
||||||
|
return datetime.now().hour
|
||||||
|
|
||||||
|
|
||||||
|
def is_peak_hour(hour=None):
|
||||||
|
"""Check if the given hour (or current hour) is a high-error period."""
|
||||||
|
if hour is None:
|
||||||
|
hour = get_current_hour()
|
||||||
|
return hour in PEAK_HOURS
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_model(hour=None):
|
||||||
|
"""
|
||||||
|
Resolve which model to use based on time of day.
|
||||||
|
|
||||||
|
Returns dict with 'provider' and 'model' keys.
|
||||||
|
During peak hours (high error rate), uses a more capable model.
|
||||||
|
During normal hours, uses the default model.
|
||||||
|
"""
|
||||||
|
if is_peak_hour(hour):
|
||||||
|
return {
|
||||||
|
"provider": PEAK_PROVIDER,
|
||||||
|
"model": PEAK_MODEL,
|
||||||
|
"reason": f"peak_hour ({hour if hour is not None else get_current_hour()}:00)",
|
||||||
|
"confidence_note": "Using stronger model during high-error period"
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"provider": DEFAULT_PROVIDER,
|
||||||
|
"model": DEFAULT_MODEL,
|
||||||
|
"reason": "normal_hour",
|
||||||
|
"confidence_note": "Default model sufficient during low-error period"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_routing_info():
|
||||||
|
"""Get full routing info including current state and config."""
|
||||||
|
hour = get_current_hour()
|
||||||
|
resolved = resolve_model(hour)
|
||||||
|
return {
|
||||||
|
"current_hour": hour,
|
||||||
|
"is_peak": is_peak_hour(hour),
|
||||||
|
"peak_hours": sorted(PEAK_HOURS),
|
||||||
|
"routing": resolved,
|
||||||
|
"config": {
|
||||||
|
"default": {"provider": DEFAULT_PROVIDER, "model": DEFAULT_MODEL},
|
||||||
|
"peak": {"provider": PEAK_PROVIDER, "model": PEAK_MODEL},
|
||||||
|
},
|
||||||
|
"source": "hermes-agent#889 — empirical audit 2026-04-12",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── CLI ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = sys.argv[1:]
|
||||||
|
|
||||||
|
# Parse --hour
|
||||||
|
hour = None
|
||||||
|
if "--hour" in args:
|
||||||
|
idx = args.index("--hour")
|
||||||
|
if idx + 1 < len(args):
|
||||||
|
hour = int(args[idx + 1])
|
||||||
|
|
||||||
|
# Parse --wrap mode
|
||||||
|
if "--wrap" in args:
|
||||||
|
# Run the remaining args as a command with model override
|
||||||
|
resolved = resolve_model(hour)
|
||||||
|
wrap_idx = args.index("--wrap")
|
||||||
|
cmd_parts = args[wrap_idx + 1:]
|
||||||
|
|
||||||
|
# Inject model/provider into environment
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["HERMES_MODEL"] = resolved["model"]
|
||||||
|
env["HERMES_PROVIDER"] = resolved["provider"]
|
||||||
|
|
||||||
|
if cmd_parts:
|
||||||
|
import subprocess
|
||||||
|
result = subprocess.run(cmd_parts, env=env)
|
||||||
|
sys.exit(result.returncode)
|
||||||
|
else:
|
||||||
|
print(json.dumps(resolved, indent=2))
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Parse --info mode
|
||||||
|
if "--info" in args:
|
||||||
|
print(json.dumps(get_routing_info(), indent=2))
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
# Default: output resolved model as JSON
|
||||||
|
resolved = resolve_model(hour)
|
||||||
|
print(json.dumps(resolved, indent=2))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
127
tests/test_context_budget.py
Normal file
127
tests/test_context_budget.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""
|
||||||
|
Tests for context budget tracker
|
||||||
|
|
||||||
|
Issue: #838
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from agent.context_budget import (
|
||||||
|
ContextBudget,
|
||||||
|
ContextBudgetTracker,
|
||||||
|
estimate_tokens,
|
||||||
|
estimate_messages_tokens,
|
||||||
|
check_context_budget,
|
||||||
|
preflight_token_check,
|
||||||
|
THRESHOLD_WARNING,
|
||||||
|
THRESHOLD_CRITICAL,
|
||||||
|
THRESHOLD_DANGER,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextBudget(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_basic_budget(self):
|
||||||
|
b = ContextBudget(context_limit=10000)
|
||||||
|
self.assertEqual(b.available, 8000) # 10000 - 2000 reserved
|
||||||
|
self.assertEqual(b.remaining, 8000)
|
||||||
|
self.assertEqual(b.utilization, 0.0)
|
||||||
|
|
||||||
|
def test_utilization(self):
|
||||||
|
b = ContextBudget(context_limit=10000, used_tokens=4000)
|
||||||
|
self.assertEqual(b.utilization, 0.5)
|
||||||
|
self.assertEqual(b.remaining, 4000)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenEstimation(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_estimate_tokens(self):
|
||||||
|
self.assertEqual(estimate_tokens(""), 0)
|
||||||
|
self.assertEqual(estimate_tokens("a" * 4), 1)
|
||||||
|
self.assertEqual(estimate_tokens("a" * 400), 100)
|
||||||
|
|
||||||
|
def test_estimate_messages(self):
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": "a" * 400},
|
||||||
|
{"role": "assistant", "content": "b" * 800},
|
||||||
|
]
|
||||||
|
tokens = estimate_messages_tokens(messages)
|
||||||
|
self.assertEqual(tokens, 300) # 100 + 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestContextBudgetTracker(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_warning_at_70_percent(self):
|
||||||
|
tracker = ContextBudgetTracker(context_limit=10000)
|
||||||
|
tracker.budget.used_tokens = 5600 # 70% of 8000 available
|
||||||
|
warning = tracker.get_warning()
|
||||||
|
self.assertIsNotNone(warning)
|
||||||
|
self.assertIn("70", warning)
|
||||||
|
|
||||||
|
def test_critical_at_85_percent(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
|
with patch("agent.context_budget.CHECKPOINT_DIR", Path(tmp)):
|
||||||
|
tracker = ContextBudgetTracker(context_limit=10000, session_id="test")
|
||||||
|
tracker.budget.used_tokens = 6800 # 85% of 8000
|
||||||
|
warning = tracker.get_warning()
|
||||||
|
self.assertIsNotNone(warning)
|
||||||
|
self.assertIn("85", warning)
|
||||||
|
|
||||||
|
def test_danger_at_95_percent(self):
|
||||||
|
tracker = ContextBudgetTracker(context_limit=10000)
|
||||||
|
tracker.budget.used_tokens = 7600 # 95% of 8000
|
||||||
|
warning = tracker.get_warning()
|
||||||
|
self.assertIsNotNone(warning)
|
||||||
|
self.assertIn("CRITICAL", warning)
|
||||||
|
|
||||||
|
def test_can_fit(self):
|
||||||
|
tracker = ContextBudgetTracker(context_limit=10000)
|
||||||
|
tracker.budget.used_tokens = 5000
|
||||||
|
self.assertTrue(tracker.can_fit(1000))
|
||||||
|
self.assertFalse(tracker.can_fit(5000))
|
||||||
|
|
||||||
|
def test_preflight_check(self):
|
||||||
|
tracker = ContextBudgetTracker(context_limit=10000)
|
||||||
|
tracker.budget.used_tokens = 5000
|
||||||
|
|
||||||
|
can_fit, msg = tracker.preflight_check("a" * 400) # 100 tokens
|
||||||
|
self.assertTrue(can_fit)
|
||||||
|
self.assertEqual(msg, "")
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckContextBudget(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_no_warning_under_threshold(self):
|
||||||
|
with patch("agent.context_budget._tracker", None):
|
||||||
|
messages = [{"role": "user", "content": "short"}]
|
||||||
|
warning = check_context_budget(messages)
|
||||||
|
self.assertIsNone(warning)
|
||||||
|
|
||||||
|
def test_warning_over_threshold(self):
|
||||||
|
with patch("agent.context_budget._tracker", None):
|
||||||
|
# Create messages that exceed 70% of default 128k context
|
||||||
|
messages = [{"role": "user", "content": "x" * 350000}] # ~87500 tokens
|
||||||
|
warning = check_context_budget(messages)
|
||||||
|
self.assertIsNotNone(warning)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStatusLine(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_green_status(self):
|
||||||
|
tracker = ContextBudgetTracker(context_limit=10000)
|
||||||
|
line = tracker.get_status_line()
|
||||||
|
self.assertIn("GREEN", line)
|
||||||
|
|
||||||
|
def test_red_status(self):
|
||||||
|
tracker = ContextBudgetTracker(context_limit=10000)
|
||||||
|
tracker.budget.used_tokens = 7600
|
||||||
|
line = tracker.get_status_line()
|
||||||
|
self.assertIn("RED", line)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
101
tests/test_credential_redact.py
Normal file
101
tests/test_credential_redact.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
"""
|
||||||
|
Tests for credential redaction
|
||||||
|
|
||||||
|
Issue: #839
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
from tools.credential_redact import (
|
||||||
|
CredentialRedactor,
|
||||||
|
redact_credentials,
|
||||||
|
redact_tool_output,
|
||||||
|
should_mask_file,
|
||||||
|
mask_sensitive_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCredentialRedaction(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_openai_key(self):
|
||||||
|
text = "api_key=sk-abc123def456ghi789jkl012mno"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertGreater(count, 0)
|
||||||
|
self.assertIn("REDACTED", redacted)
|
||||||
|
self.assertNotIn("sk-abc123", redacted)
|
||||||
|
|
||||||
|
def test_github_token(self):
|
||||||
|
text = "token: ghp_1234567890abcdef1234567890abcdef12345678"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertGreater(count, 0)
|
||||||
|
self.assertIn("REDACTED", redacted)
|
||||||
|
|
||||||
|
def test_bearer_token(self):
|
||||||
|
text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertGreater(count, 0)
|
||||||
|
self.assertIn("REDACTED", redacted)
|
||||||
|
|
||||||
|
def test_password(self):
|
||||||
|
text = "password: mySecretPassword123"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertGreater(count, 0)
|
||||||
|
self.assertIn("REDACTED", redacted)
|
||||||
|
|
||||||
|
def test_aws_key(self):
|
||||||
|
text = "AWS_ACCESS_KEY_ID=AKIAIOSFODNN7EXAMPLE"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertGreater(count, 0)
|
||||||
|
self.assertIn("REDACTED", redacted)
|
||||||
|
|
||||||
|
def test_database_url(self):
|
||||||
|
text = "DATABASE_URL=postgres://user:pass@localhost/db"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertGreater(count, 0)
|
||||||
|
self.assertIn("REDACTED", redacted)
|
||||||
|
|
||||||
|
def test_clean_text_unchanged(self):
|
||||||
|
text = "Hello world, this is a normal message"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertEqual(count, 0)
|
||||||
|
self.assertEqual(redacted, text)
|
||||||
|
|
||||||
|
def test_multiple_credentials(self):
|
||||||
|
text = "key1=sk-abc123def456ghi789jkl012mno and token: ghp_1234567890abcdef1234567890abcdef12345678"
|
||||||
|
redacted, count = redact_credentials(text)
|
||||||
|
self.assertGreaterEqual(count, 2)
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolOutputRedaction(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_redaction_notice(self):
|
||||||
|
output = "Running with key sk-abc123def456ghi789jkl012mno"
|
||||||
|
redacted, notice = redact_tool_output("terminal", output)
|
||||||
|
self.assertIn("REDACTED", notice)
|
||||||
|
self.assertIn("terminal", notice)
|
||||||
|
|
||||||
|
def test_no_notice_when_clean(self):
|
||||||
|
output = "Hello world"
|
||||||
|
redacted, notice = redact_tool_output("terminal", output)
|
||||||
|
self.assertEqual(notice, "")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSensitiveFileMasking(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_env_file_detected(self):
|
||||||
|
self.assertTrue(should_mask_file("/path/to/.env"))
|
||||||
|
self.assertTrue(should_mask_file("/path/to/.env.local"))
|
||||||
|
self.assertTrue(should_mask_file("/path/to/config.yaml"))
|
||||||
|
|
||||||
|
def test_normal_file_not_detected(self):
|
||||||
|
self.assertFalse(should_mask_file("/path/to/readme.md"))
|
||||||
|
self.assertFalse(should_mask_file("/path/to/code.py"))
|
||||||
|
|
||||||
|
def test_mask_env_file(self):
|
||||||
|
content = "API_KEY=sk-abc123\nDATABASE_URL=postgres://u:p@h/d\nNORMAL=value"
|
||||||
|
masked = mask_sensitive_file(content, ".env")
|
||||||
|
self.assertIn("[REDACTED]", masked)
|
||||||
|
self.assertIn("NORMAL=value", masked)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
79
tests/test_crisis_resources.py
Normal file
79
tests/test_crisis_resources.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Tests for 988 Crisis Lifeline integration (#673)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from agent.crisis_resources import (
|
||||||
|
LIFELINE_988,
|
||||||
|
LIFELINE_988_TEXT,
|
||||||
|
LIFELINE_988_CHAT,
|
||||||
|
LIFELINE_988_SPANISH,
|
||||||
|
CRISIS_TEXT_LINE,
|
||||||
|
EMERGENCY_911,
|
||||||
|
ALL_RESOURCES,
|
||||||
|
get_crisis_resources,
|
||||||
|
format_crisis_resources,
|
||||||
|
get_immediate_help_message,
|
||||||
|
CrisisResource,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCrisisResources:
|
||||||
|
def test_988_phone(self):
|
||||||
|
assert "988" in LIFELINE_988.contact
|
||||||
|
assert "24/7" in LIFELINE_988.available
|
||||||
|
|
||||||
|
def test_988_text(self):
|
||||||
|
assert "HOME" in LIFELINE_988_TEXT.contact
|
||||||
|
assert "988" in LIFELINE_988_TEXT.contact
|
||||||
|
|
||||||
|
def test_988_chat(self):
|
||||||
|
assert "988lifeline.org/chat" in LIFELINE_988_CHAT.url
|
||||||
|
|
||||||
|
def test_988_spanish(self):
|
||||||
|
assert "1-888-628-9454" in LIFELINE_988_SPANISH.contact
|
||||||
|
assert LIFELINE_988_SPANISH.language == "Spanish"
|
||||||
|
|
||||||
|
def test_crisis_text_line(self):
|
||||||
|
assert "741741" in CRISIS_TEXT_LINE.contact
|
||||||
|
|
||||||
|
def test_911(self):
|
||||||
|
assert "911" in EMERGENCY_911.contact
|
||||||
|
|
||||||
|
def test_all_resources_not_empty(self):
|
||||||
|
assert len(ALL_RESOURCES) >= 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetResources:
|
||||||
|
def test_returns_all_by_default(self):
|
||||||
|
assert len(get_crisis_resources()) == len(ALL_RESOURCES)
|
||||||
|
|
||||||
|
def test_filter_english(self):
|
||||||
|
english = get_crisis_resources("English")
|
||||||
|
assert all(r.language == "English" for r in english)
|
||||||
|
assert len(english) > 0
|
||||||
|
|
||||||
|
def test_filter_spanish(self):
|
||||||
|
spanish = get_crisis_resources("Spanish")
|
||||||
|
assert len(spanish) >= 1
|
||||||
|
assert all(r.language == "Spanish" for r in spanish)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatting:
|
||||||
|
def test_format_includes_988(self):
|
||||||
|
msg = format_crisis_resources()
|
||||||
|
assert "988" in msg
|
||||||
|
|
||||||
|
def test_format_includes_741741(self):
|
||||||
|
msg = format_crisis_resources()
|
||||||
|
assert "741741" in msg
|
||||||
|
|
||||||
|
def test_format_includes_911(self):
|
||||||
|
msg = format_crisis_resources()
|
||||||
|
assert "911" in msg
|
||||||
|
|
||||||
|
def test_immediate_help_includes_911_first(self):
|
||||||
|
msg = get_immediate_help_message()
|
||||||
|
assert msg.startswith("If you are in immediate danger")
|
||||||
|
|
||||||
|
def test_format_not_empty(self):
|
||||||
|
msg = format_crisis_resources()
|
||||||
|
assert len(msg) > 100
|
||||||
274
tests/test_poka_yoke.py
Normal file
274
tests/test_poka_yoke.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
test_poka_yoke.py — Tests for the tool call validation firewall.
|
||||||
|
|
||||||
|
Covers: unknown tool, bad param type, missing required arg,
|
||||||
|
extra unknown param, enum validation, closest-name suggestion.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
|
||||||
|
|
||||||
|
from tools.poka_yoke import (
|
||||||
|
validate_tool_call,
|
||||||
|
_find_closest_name,
|
||||||
|
_validate_type,
|
||||||
|
_truncate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Mock Registry ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class MockEntry:
|
||||||
|
def __init__(self, name, schema):
|
||||||
|
self.name = name
|
||||||
|
self.schema = schema
|
||||||
|
self.toolset = "test"
|
||||||
|
|
||||||
|
|
||||||
|
MOCK_TOOLS = {
|
||||||
|
"read_file": MockEntry("read_file", {
|
||||||
|
"name": "read_file",
|
||||||
|
"description": "Read a file",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {"type": "string", "description": "File path"},
|
||||||
|
"offset": {"type": "integer", "description": "Start line"},
|
||||||
|
"limit": {"type": "integer", "description": "Max lines"},
|
||||||
|
},
|
||||||
|
"required": ["path"],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
"web_search": MockEntry("web_search", {
|
||||||
|
"name": "web_search",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"},
|
||||||
|
"max_results": {"type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["query"],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
"write_file": MockEntry("write_file", {
|
||||||
|
"name": "write_file",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {"type": "string"},
|
||||||
|
"content": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["path", "content"],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
"terminal": MockEntry("terminal", {
|
||||||
|
"name": "terminal",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {"type": "string"},
|
||||||
|
"timeout": {"type": "integer"},
|
||||||
|
"background": {"type": "boolean"},
|
||||||
|
},
|
||||||
|
"required": ["command"],
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_registry():
|
||||||
|
"""Create a mock registry."""
|
||||||
|
mock_reg = MagicMock()
|
||||||
|
mock_reg.get_entry = lambda name: MOCK_TOOLS.get(name)
|
||||||
|
mock_reg.get_all_tool_names = lambda: list(MOCK_TOOLS.keys())
|
||||||
|
return mock_reg
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: Unknown Tool ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestUnknownTool:
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_unknown_tool_rejected(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = None
|
||||||
|
mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS.keys())
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call("nonexistent_tool", {})
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert len(msgs) > 0
|
||||||
|
assert "nonexistent_tool" in msgs[0]
|
||||||
|
assert "Unknown tool" in msgs[0]
|
||||||
|
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_unknown_tool_lists_available(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = None
|
||||||
|
mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS.keys())
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call("foo", {})
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert "read_file" in msgs[0]
|
||||||
|
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_close_name_suggests_correction(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = None
|
||||||
|
mock_reg.get_all_tool_names.return_value = list(MOCK_TOOLS.keys())
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call("readfile", {})
|
||||||
|
|
||||||
|
assert "read_file" in msgs[0]
|
||||||
|
assert name == "read_file"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: Missing Required Args ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestMissingRequired:
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_missing_required_rejected(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call("read_file", {})
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert any("Missing required" in m for m in msgs)
|
||||||
|
assert any("'path'" in m for m in msgs)
|
||||||
|
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_multiple_missing_required(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call("write_file", {})
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert any("'path'" in m for m in msgs)
|
||||||
|
assert any("'content'" in m for m in msgs)
|
||||||
|
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_required_present_passes(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call(
|
||||||
|
"read_file", {"path": "test.txt"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: Type Validation ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestTypeValidation:
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_wrong_type_rejected(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call(
|
||||||
|
"read_file", {"path": "test.txt", "offset": "not_a_number"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert any("offset" in m and "integer" in m for m in msgs)
|
||||||
|
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_string_to_int_coercion(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call(
|
||||||
|
"read_file", {"path": "test.txt", "offset": "42"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert params is not None
|
||||||
|
assert params["offset"] == 42
|
||||||
|
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_boolean_coercion(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["terminal"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call(
|
||||||
|
"terminal", {"command": "ls", "background": "true"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert params is not None
|
||||||
|
assert params["background"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: Unknown Parameters ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestUnknownParams:
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_unknown_param_removed(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call(
|
||||||
|
"read_file", {"path": "test.txt", "bogus_param": "value"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert params is not None
|
||||||
|
assert "bogus_param" not in params
|
||||||
|
assert "path" in params
|
||||||
|
assert any("Unknown parameter" in m for m in msgs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: Valid Calls Pass Through ────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestValidCalls:
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_valid_read_file(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["read_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call(
|
||||||
|
"read_file", {"path": "test.txt", "offset": 1, "limit": 100}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert name is None
|
||||||
|
assert params is None
|
||||||
|
assert msgs == []
|
||||||
|
|
||||||
|
@patch("tools.poka_yoke.registry")
|
||||||
|
def test_valid_write_file(self, mock_reg):
|
||||||
|
mock_reg.get_entry.return_value = MOCK_TOOLS["write_file"]
|
||||||
|
|
||||||
|
is_valid, name, params, msgs = validate_tool_call(
|
||||||
|
"write_file", {"path": "out.txt", "content": "hello"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test: Helper Functions ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestHelpers:
|
||||||
|
def test_find_closest_exact_prefix(self):
|
||||||
|
assert _find_closest_name("readfil", ["read_file", "write_file"]) == "read_file"
|
||||||
|
|
||||||
|
def test_find_closest_substring(self):
|
||||||
|
assert _find_closest_name("file", ["read_file", "web_search"]) == "read_file"
|
||||||
|
|
||||||
|
def test_find_closest_no_match(self):
|
||||||
|
assert _find_closest_name("xyzzy", ["read_file", "write_file"]) is None
|
||||||
|
|
||||||
|
def test_validate_type_string(self):
|
||||||
|
ok, val = _validate_type("x", "hello", "string")
|
||||||
|
assert ok is True
|
||||||
|
|
||||||
|
def test_validate_type_int_coercion(self):
|
||||||
|
ok, val = _validate_type("x", "42", "integer")
|
||||||
|
assert ok is True
|
||||||
|
assert val == 42
|
||||||
|
|
||||||
|
def test_validate_type_int_bad(self):
|
||||||
|
ok, val = _validate_type("x", "not_int", "integer")
|
||||||
|
assert ok is False
|
||||||
|
|
||||||
|
def test_truncate(self):
|
||||||
|
assert _truncate("hello", 10) == "hello"
|
||||||
|
assert _truncate("hello world", 8) == "hello..."
|
||||||
76
tests/test_profile_isolation.py
Normal file
76
tests/test_profile_isolation.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""Tests for profile session isolation (#891)."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import json
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
# Override paths for testing
|
||||||
|
import agent.profile_isolation as iso_mod
|
||||||
|
_test_dir = Path(tempfile.mkdtemp())
|
||||||
|
iso_mod.PROFILE_TAGS_FILE = _test_dir / "tags.json"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tag_session():
|
||||||
|
"""Session gets tagged with profile."""
|
||||||
|
profile = iso_mod.tag_session("sess-1", "sprint")
|
||||||
|
assert profile == "sprint"
|
||||||
|
assert iso_mod.get_session_profile("sess-1") == "sprint"
|
||||||
|
|
||||||
|
|
||||||
|
def test_default_profile():
|
||||||
|
"""Sessions tagged with default when no profile specified."""
|
||||||
|
profile = iso_mod.tag_session("sess-2")
|
||||||
|
assert profile is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_session_profile():
|
||||||
|
"""Can retrieve profile for tagged session."""
|
||||||
|
iso_mod.tag_session("sess-3", "fenrir")
|
||||||
|
assert iso_mod.get_session_profile("sess-3") == "fenrir"
|
||||||
|
|
||||||
|
|
||||||
|
def test_untagged_returns_none():
|
||||||
|
"""Untagged session returns None."""
|
||||||
|
assert iso_mod.get_session_profile("nonexistent") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_profile_stats():
|
||||||
|
"""Stats reflect tagged sessions."""
|
||||||
|
iso_mod.tag_session("s1", "default")
|
||||||
|
iso_mod.tag_session("s2", "sprint")
|
||||||
|
iso_mod.tag_session("s3", "sprint")
|
||||||
|
stats = iso_mod.get_profile_stats()
|
||||||
|
assert stats["total_tagged_sessions"] >= 3
|
||||||
|
assert "sprint" in stats["profile_counts"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_sessions():
|
||||||
|
"""Filter returns only matching profile sessions."""
|
||||||
|
iso_mod.tag_session("filter-1", "alpha")
|
||||||
|
iso_mod.tag_session("filter-2", "beta")
|
||||||
|
iso_mod.tag_session("filter-3", "alpha")
|
||||||
|
|
||||||
|
sessions = [
|
||||||
|
{"session_id": "filter-1"},
|
||||||
|
{"session_id": "filter-2"},
|
||||||
|
{"session_id": "filter-3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered = iso_mod.filter_sessions_by_profile(sessions, "alpha")
|
||||||
|
ids = [s["session_id"] for s in filtered]
|
||||||
|
assert "filter-1" in ids
|
||||||
|
assert "filter-3" in ids
|
||||||
|
assert "filter-2" not in ids
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tests = [test_tag_session, test_default_profile, test_get_session_profile,
|
||||||
|
test_untagged_returns_none, test_profile_stats, test_filter_sessions]
|
||||||
|
for t in tests:
|
||||||
|
print(f"Running {t.__name__}...")
|
||||||
|
t()
|
||||||
|
print(" PASS")
|
||||||
|
print("\nAll tests passed.")
|
||||||
302
tests/test_skill_manager_autorevert.py
Normal file
302
tests/test_skill_manager_autorevert.py
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
"""
|
||||||
|
Integration tests for poka-yoke auto-revert on incomplete skill edits (#923).
|
||||||
|
|
||||||
|
Verifies the transactional write-validate-commit-or-rollback pattern:
|
||||||
|
- Backup created before every write
|
||||||
|
- Post-write validation triggers revert on corrupted/empty file
|
||||||
|
- Successful writes clean up the backup
|
||||||
|
- At most MAX_BACKUPS_PER_FILE backups retained per file
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from tools.skill_manager_tool import (
|
||||||
|
MAX_BACKUPS_PER_FILE,
|
||||||
|
_backup_skill_file,
|
||||||
|
_cleanup_old_backups,
|
||||||
|
_edit_skill,
|
||||||
|
_patch_skill,
|
||||||
|
_revert_from_backup,
|
||||||
|
_validate_written_file,
|
||||||
|
_write_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
VALID_SKILL_MD = """\
|
||||||
|
---
|
||||||
|
name: test-skill
|
||||||
|
description: A skill for testing auto-revert
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Test skill body content.
|
||||||
|
"""
|
||||||
|
|
||||||
|
VALID_UPDATED_MD = """\
|
||||||
|
---
|
||||||
|
name: test-skill
|
||||||
|
description: Updated description
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Updated test skill body.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_skill(tmp_path: Path, content: str = VALID_SKILL_MD) -> Path:
|
||||||
|
"""Write a minimal SKILL.md in *tmp_path* and return its path."""
|
||||||
|
skill_md = tmp_path / "SKILL.md"
|
||||||
|
skill_md.write_text(content, encoding="utf-8")
|
||||||
|
return skill_md
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: _backup_skill_file
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestBackupSkillFile:
|
||||||
|
def test_creates_bak_file(self, tmp_path):
|
||||||
|
skill_md = _make_skill(tmp_path)
|
||||||
|
backup = _backup_skill_file(skill_md)
|
||||||
|
assert backup is not None
|
||||||
|
assert backup.exists()
|
||||||
|
assert ".bak." in backup.name
|
||||||
|
|
||||||
|
def test_backup_preserves_content(self, tmp_path):
|
||||||
|
skill_md = _make_skill(tmp_path)
|
||||||
|
backup = _backup_skill_file(skill_md)
|
||||||
|
assert backup.read_text(encoding="utf-8") == VALID_SKILL_MD
|
||||||
|
|
||||||
|
def test_no_backup_for_nonexistent_file(self, tmp_path):
|
||||||
|
missing = tmp_path / "SKILL.md"
|
||||||
|
assert _backup_skill_file(missing) is None
|
||||||
|
|
||||||
|
def test_backup_name_contains_timestamp(self, tmp_path):
|
||||||
|
skill_md = _make_skill(tmp_path)
|
||||||
|
before = int(time.time())
|
||||||
|
backup = _backup_skill_file(skill_md)
|
||||||
|
after = int(time.time())
|
||||||
|
ts = int(backup.name.split(".bak.")[-1])
|
||||||
|
assert before <= ts <= after
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: _cleanup_old_backups
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestCleanupOldBackups:
|
||||||
|
def _create_backups(self, skill_md: Path, n: int) -> list:
|
||||||
|
backups = []
|
||||||
|
for i in range(n):
|
||||||
|
bp = skill_md.parent / f"{skill_md.name}.bak.{1000 + i}"
|
||||||
|
bp.write_text("backup content", encoding="utf-8")
|
||||||
|
backups.append(bp)
|
||||||
|
return backups
|
||||||
|
|
||||||
|
def test_prunes_excess_backups(self, tmp_path):
|
||||||
|
skill_md = _make_skill(tmp_path)
|
||||||
|
self._create_backups(skill_md, MAX_BACKUPS_PER_FILE + 2)
|
||||||
|
_cleanup_old_backups(skill_md)
|
||||||
|
remaining = list(tmp_path.glob(f"SKILL.md.bak.*"))
|
||||||
|
assert len(remaining) == MAX_BACKUPS_PER_FILE
|
||||||
|
|
||||||
|
def test_keeps_backups_within_limit(self, tmp_path):
|
||||||
|
skill_md = _make_skill(tmp_path)
|
||||||
|
self._create_backups(skill_md, MAX_BACKUPS_PER_FILE)
|
||||||
|
_cleanup_old_backups(skill_md)
|
||||||
|
remaining = list(tmp_path.glob("SKILL.md.bak.*"))
|
||||||
|
assert len(remaining) == MAX_BACKUPS_PER_FILE
|
||||||
|
|
||||||
|
def test_noop_when_no_backups(self, tmp_path):
|
||||||
|
skill_md = _make_skill(tmp_path)
|
||||||
|
_cleanup_old_backups(skill_md) # should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: _validate_written_file
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestValidateWrittenFile:
|
||||||
|
def test_valid_skill_md(self, tmp_path):
|
||||||
|
skill_md = _make_skill(tmp_path)
|
||||||
|
assert _validate_written_file(skill_md, is_skill_md=True) is None
|
||||||
|
|
||||||
|
def test_empty_file_fails(self, tmp_path):
|
||||||
|
skill_md = tmp_path / "SKILL.md"
|
||||||
|
skill_md.write_text("", encoding="utf-8")
|
||||||
|
err = _validate_written_file(skill_md, is_skill_md=False)
|
||||||
|
assert err is not None
|
||||||
|
assert "empty" in err.lower()
|
||||||
|
|
||||||
|
def test_broken_frontmatter_fails(self, tmp_path):
|
||||||
|
skill_md = tmp_path / "SKILL.md"
|
||||||
|
skill_md.write_text("Not a skill\nno frontmatter\n", encoding="utf-8")
|
||||||
|
err = _validate_written_file(skill_md, is_skill_md=True)
|
||||||
|
assert err is not None
|
||||||
|
|
||||||
|
def test_missing_required_field_fails(self, tmp_path):
|
||||||
|
skill_md = tmp_path / "SKILL.md"
|
||||||
|
skill_md.write_text("---\ndescription: no name\n---\nbody\n", encoding="utf-8")
|
||||||
|
err = _validate_written_file(skill_md, is_skill_md=True)
|
||||||
|
assert err is not None
|
||||||
|
assert "name" in err.lower()
|
||||||
|
|
||||||
|
def test_missing_file_returns_error(self, tmp_path):
|
||||||
|
missing = tmp_path / "SKILL.md"
|
||||||
|
err = _validate_written_file(missing, is_skill_md=False)
|
||||||
|
assert err is not None
|
||||||
|
|
||||||
|
def test_non_skill_md_only_checks_emptiness(self, tmp_path):
|
||||||
|
ref = tmp_path / "references" / "guide.md"
|
||||||
|
ref.parent.mkdir()
|
||||||
|
ref.write_text("# Guide\nsome content\n", encoding="utf-8")
|
||||||
|
assert _validate_written_file(ref, is_skill_md=False) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Unit tests: _revert_from_backup
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestRevertFromBackup:
|
||||||
|
def test_restores_from_backup(self, tmp_path):
|
||||||
|
original = "original content"
|
||||||
|
skill_md = tmp_path / "SKILL.md"
|
||||||
|
skill_md.write_text(original, encoding="utf-8")
|
||||||
|
backup = tmp_path / "SKILL.md.bak.99999"
|
||||||
|
backup.write_text(original, encoding="utf-8")
|
||||||
|
|
||||||
|
skill_md.write_text("corrupted content", encoding="utf-8")
|
||||||
|
_revert_from_backup(skill_md, backup)
|
||||||
|
assert skill_md.read_text(encoding="utf-8") == original
|
||||||
|
|
||||||
|
def test_removes_file_when_no_backup(self, tmp_path):
|
||||||
|
skill_md = tmp_path / "SKILL.md"
|
||||||
|
skill_md.write_text("corrupted", encoding="utf-8")
|
||||||
|
_revert_from_backup(skill_md, None)
|
||||||
|
assert not skill_md.exists()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests: _edit_skill auto-revert
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEditSkillAutoRevert:
|
||||||
|
@pytest.fixture
|
||||||
|
def skill_dir(self, tmp_path):
|
||||||
|
"""Create a minimal skill directory and patch _find_skill."""
|
||||||
|
d = tmp_path / "test-skill"
|
||||||
|
d.mkdir()
|
||||||
|
skill_md = d / "SKILL.md"
|
||||||
|
skill_md.write_text(VALID_SKILL_MD, encoding="utf-8")
|
||||||
|
return d
|
||||||
|
|
||||||
|
def test_successful_edit_removes_backup(self, skill_dir):
|
||||||
|
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||||
|
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||||
|
mock_find.return_value = {"path": skill_dir}
|
||||||
|
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||||
|
assert len(backups) == 0
|
||||||
|
|
||||||
|
def test_revert_when_post_write_validation_fails(self, skill_dir):
|
||||||
|
"""Simulate a write that produces an empty file on disk."""
|
||||||
|
skill_md = skill_dir / "SKILL.md"
|
||||||
|
|
||||||
|
def corrupt_write(path, content, **kw):
|
||||||
|
# Write an empty file to simulate truncation
|
||||||
|
path.write_text("", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||||
|
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||||
|
mock_find.return_value = {"path": skill_dir}
|
||||||
|
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "reverted" in result["error"].lower()
|
||||||
|
# Original content restored
|
||||||
|
assert skill_md.read_text(encoding="utf-8") == VALID_SKILL_MD
|
||||||
|
|
||||||
|
def test_backup_preserved_after_revert(self, skill_dir):
|
||||||
|
"""A .bak file should survive when the edit is reverted (debugging aid)."""
|
||||||
|
def corrupt_write(path, content, **kw):
|
||||||
|
path.write_text("", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||||
|
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||||
|
mock_find.return_value = {"path": skill_dir}
|
||||||
|
_edit_skill("test-skill", VALID_UPDATED_MD)
|
||||||
|
|
||||||
|
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||||
|
assert len(backups) == 1
|
||||||
|
|
||||||
|
def test_max_backups_enforced_after_multiple_edits(self, skill_dir):
|
||||||
|
"""After many successful edits, at most MAX_BACKUPS_PER_FILE .bak files remain."""
|
||||||
|
n = MAX_BACKUPS_PER_FILE + 4
|
||||||
|
for i in range(n):
|
||||||
|
# Plant stale backup files to simulate prior runs
|
||||||
|
bp = skill_dir / f"SKILL.md.bak.{1000 + i}"
|
||||||
|
bp.write_text("old backup", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||||
|
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||||
|
mock_find.return_value = {"path": skill_dir}
|
||||||
|
result = _edit_skill("test-skill", VALID_UPDATED_MD)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
backups = list(skill_dir.glob("SKILL.md.bak.*"))
|
||||||
|
assert len(backups) <= MAX_BACKUPS_PER_FILE
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests: _patch_skill auto-revert
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestPatchSkillAutoRevert:
|
||||||
|
@pytest.fixture
|
||||||
|
def skill_dir(self, tmp_path):
|
||||||
|
d = tmp_path / "test-skill"
|
||||||
|
d.mkdir()
|
||||||
|
(d / "SKILL.md").write_text(VALID_SKILL_MD, encoding="utf-8")
|
||||||
|
return d
|
||||||
|
|
||||||
|
def test_successful_patch_removes_backup(self, skill_dir):
|
||||||
|
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||||
|
patch("tools.skill_manager_tool._security_scan_skill", return_value=None):
|
||||||
|
mock_find.return_value = {"path": skill_dir}
|
||||||
|
result = _patch_skill(
|
||||||
|
"test-skill",
|
||||||
|
"A skill for testing auto-revert",
|
||||||
|
"Updated description",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert len(list(skill_dir.glob("SKILL.md.bak.*"))) == 0
|
||||||
|
|
||||||
|
def test_revert_on_corrupt_write(self, skill_dir):
|
||||||
|
skill_md = skill_dir / "SKILL.md"
|
||||||
|
original = skill_md.read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
def corrupt_write(path, content, **kw):
|
||||||
|
path.write_text("", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("tools.skill_manager_tool._find_skill") as mock_find, \
|
||||||
|
patch("tools.skill_manager_tool._atomic_write_text", side_effect=corrupt_write):
|
||||||
|
mock_find.return_value = {"path": skill_dir}
|
||||||
|
result = _patch_skill(
|
||||||
|
"test-skill",
|
||||||
|
"A skill for testing",
|
||||||
|
"A skill for testing auto-revert",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["success"] is False
|
||||||
|
assert "reverted" in result["error"].lower()
|
||||||
|
assert skill_md.read_text(encoding="utf-8") == original
|
||||||
82
tests/test_syntax_validation.py
Normal file
82
tests/test_syntax_validation.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Tests for Python syntax validation in execute_code."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
# Import the validation function directly
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
|
||||||
|
from tools.code_execution_tool import _validate_python_syntax
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidatePythonSyntax:
|
||||||
|
"""Test _validate_python_syntax catches errors before subprocess spawn."""
|
||||||
|
|
||||||
|
def test_valid_code_returns_none(self):
|
||||||
|
assert _validate_python_syntax("print('hello')") is None
|
||||||
|
|
||||||
|
def test_valid_multiline_returns_none(self):
|
||||||
|
code = """
|
||||||
|
import os
|
||||||
|
def foo():
|
||||||
|
return 42
|
||||||
|
result = foo()
|
||||||
|
"""
|
||||||
|
assert _validate_python_syntax(code) is None
|
||||||
|
|
||||||
|
def test_syntax_error_detected(self):
|
||||||
|
result = _validate_python_syntax("def foo(
|
||||||
|
")
|
||||||
|
assert result is not None
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["syntax_error"] is True
|
||||||
|
assert "line" in data
|
||||||
|
assert "message" in data
|
||||||
|
|
||||||
|
def test_missing_colon(self):
|
||||||
|
result = _validate_python_syntax("def foo()
|
||||||
|
pass")
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["syntax_error"] is True
|
||||||
|
assert data["line"] == 1
|
||||||
|
|
||||||
|
def test_unmatched_paren(self):
|
||||||
|
result = _validate_python_syntax("print('hello'")
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["syntax_error"] is True
|
||||||
|
|
||||||
|
def test_indentation_error(self):
|
||||||
|
result = _validate_python_syntax("def foo():
|
||||||
|
pass")
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["syntax_error"] is True
|
||||||
|
assert data["line"] == 2
|
||||||
|
|
||||||
|
def test_invalid_character(self):
|
||||||
|
result = _validate_python_syntax("x = 5 √ 2")
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["syntax_error"] is True
|
||||||
|
|
||||||
|
def test_error_format_has_required_fields(self):
|
||||||
|
result = _validate_python_syntax("def(
|
||||||
|
")
|
||||||
|
data = json.loads(result)
|
||||||
|
assert "error" in data
|
||||||
|
assert "syntax_error" in data
|
||||||
|
assert "line" in data
|
||||||
|
assert "offset" in data
|
||||||
|
assert "message" in data
|
||||||
|
|
||||||
|
def test_empty_string_returns_none(self):
|
||||||
|
# Empty code is caught by the guard before validation
|
||||||
|
# But if called directly, ast.parse("") is valid
|
||||||
|
assert _validate_python_syntax("") is None
|
||||||
|
|
||||||
|
def test_comment_only_returns_none(self):
|
||||||
|
assert _validate_python_syntax("# just a comment") is None
|
||||||
|
|
||||||
|
def test_complex_valid_code(self):
|
||||||
|
code =
|
||||||
58
tests/test_time_aware_routing.py
Normal file
58
tests/test_time_aware_routing.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
"""Tests for time-aware model routing."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
|
||||||
|
from agent.time_aware_routing import (
|
||||||
|
resolve_time_aware_model,
|
||||||
|
get_hour_error_rate,
|
||||||
|
is_off_hours,
|
||||||
|
get_routing_report,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestErrorRates:
|
||||||
|
def test_evening_high_error(self):
|
||||||
|
assert get_hour_error_rate(18) == 9.4
|
||||||
|
assert get_hour_error_rate(19) == 8.1
|
||||||
|
|
||||||
|
def test_morning_low_error(self):
|
||||||
|
assert get_hour_error_rate(9) == 4.0
|
||||||
|
assert get_hour_error_rate(12) == 4.0
|
||||||
|
|
||||||
|
def test_default_for_unknown(self):
|
||||||
|
assert get_hour_error_rate(15) == 4.0
|
||||||
|
|
||||||
|
|
||||||
|
class TestOffHours:
|
||||||
|
def test_evening_is_off_hours(self):
|
||||||
|
assert is_off_hours(20) is True
|
||||||
|
assert is_off_hours(2) is True
|
||||||
|
|
||||||
|
def test_business_hours_not_off(self):
|
||||||
|
assert is_off_hours(9) is False
|
||||||
|
assert is_off_hours(14) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouting:
    """Model-selection decisions for interactive vs. cron contexts."""

    def test_interactive_uses_base_model(self):
        decision = resolve_time_aware_model(
            "my-model", "my-provider", is_cron=False, hour=18
        )
        assert decision.model == "my-model"
        assert "Interactive" in decision.reason

    def test_cron_low_error_uses_base(self):
        decision = resolve_time_aware_model("cheap-model", is_cron=True, hour=10)
        assert decision.model == "cheap-model"

    def test_cron_high_error_upgrades(self):
        decision = resolve_time_aware_model("cheap-model", is_cron=True, hour=18)
        assert decision.model != "cheap-model"
        assert decision.is_off_hours is True

    def test_routing_report(self):
        report = get_routing_report()
        for needle in ("Time-Aware Model Routing", "18:00"):
            assert needle in report
|
||||||
237
tests/test_token_budget.py
Normal file
237
tests/test_token_budget.py
Normal file
@@ -0,0 +1,237 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Tests for agent/token_budget.py — Poka-yoke context overflow guard.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
from agent.token_budget import (
|
||||||
|
TokenBudget,
|
||||||
|
BudgetLevel,
|
||||||
|
BudgetStatus,
|
||||||
|
WARN_PERCENT,
|
||||||
|
CAUTION_PERCENT,
|
||||||
|
CRITICAL_PERCENT,
|
||||||
|
STOP_PERCENT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def budget():
    """Standard 128K context budget (typical large-model context window)."""
    return TokenBudget(context_length=128_000)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def small_budget():
    """4K context for tight testing (thresholds land on small token counts)."""
    return TokenBudget(context_length=4_000)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Threshold Levels ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestThresholds:
    """Budget level transitions at the 60/80/90/95 percent thresholds."""

    def _status_at(self, budget, fraction):
        # Drive the budget to the given utilisation of a 128K window
        # and return the resulting status.
        budget.update(int(128_000 * fraction))
        return budget.check()

    def test_normal_below_60(self, budget):
        budget.update(50_000)  # ~39% of the window
        status = budget.check()
        assert status.level == BudgetLevel.NORMAL
        assert not (
            status.should_compress
            or status.should_block_tools
            or status.should_terminate
        )

    def test_warning_at_60(self, budget):
        status = self._status_at(budget, 0.62)
        assert status.level == BudgetLevel.WARNING
        assert not status.should_compress
        assert not status.should_block_tools

    def test_caution_at_80(self, budget):
        status = self._status_at(budget, 0.82)
        assert status.level == BudgetLevel.CAUTION
        assert status.should_compress
        assert not status.should_block_tools
        assert not status.should_terminate

    def test_critical_at_90(self, budget):
        status = self._status_at(budget, 0.91)
        assert status.level == BudgetLevel.CRITICAL
        assert status.should_compress
        assert status.should_block_tools
        assert not status.should_terminate

    def test_stop_at_95(self, budget):
        status = self._status_at(budget, 0.96)
        assert status.level == BudgetLevel.STOP
        assert status.should_compress
        assert status.should_block_tools
        assert status.should_terminate

    def test_small_context_thresholds(self, small_budget):
        # Token counts chosen to land just above each 4K-context threshold.
        for tokens, level in (
            (2450, BudgetLevel.WARNING),   # 4K * 0.61
            (3250, BudgetLevel.CAUTION),   # 4K * 0.81
            (3650, BudgetLevel.CRITICAL),  # 4K * 0.91
            (3850, BudgetLevel.STOP),      # 4K * 0.96
        ):
            small_budget.update(tokens)
            assert small_budget.check().level == level
|
||||||
|
|
||||||
|
|
||||||
|
# ── Convenience Methods ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestConvenienceMethods:
    """Boolean helpers flip exactly at their threshold percentages."""

    def _fill_to(self, budget, fraction):
        # Set usage to the given fraction of the 128K window.
        budget.update(int(128_000 * fraction))

    def test_should_compress(self, budget):
        self._fill_to(budget, 0.79)
        assert not budget.should_compress()
        self._fill_to(budget, 0.80)
        assert budget.should_compress()

    def test_should_block_tools(self, budget):
        self._fill_to(budget, 0.89)
        assert not budget.should_block_tools()
        self._fill_to(budget, 0.90)
        assert budget.should_block_tools()

    def test_should_terminate(self, budget):
        self._fill_to(budget, 0.94)
        assert not budget.should_terminate()
        self._fill_to(budget, 0.95)
        assert budget.should_terminate()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool Output Budgeting ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestToolOutputBudget:
    """Per-level tool-output allowances and output truncation."""

    def _allowance_at(self, budget, fraction):
        # Fill the budget to the given fraction and return the allowance.
        budget.update(int(128_000 * fraction))
        return budget.tool_output_budget()

    def test_normal_budget(self, budget):
        assert self._allowance_at(budget, 0.50) == 50_000

    def test_warning_budget(self, budget):
        assert self._allowance_at(budget, 0.65) == 20_000

    def test_caution_budget(self, budget):
        assert self._allowance_at(budget, 0.85) == 8_000

    def test_critical_budget(self, budget):
        assert self._allowance_at(budget, 0.92) == 2_000

    def test_truncate_short_unchanged(self, budget):
        text = "short text"
        assert budget.truncate_tool_output(text, max_chars=1000) == text

    def test_truncate_long(self, budget):
        blob = "A" * 100_000
        clipped = budget.truncate_tool_output(blob, max_chars=5_000)
        assert len(clipped) <= 5_100  # small overhead for the notice
        assert "truncated" in clipped
        # Both the head and the tail of the original must survive.
        assert "A" in clipped[:2500]
        assert "A" in clipped[-2500:]

    def test_truncate_very_small(self, budget):
        clipped = budget.truncate_tool_output("X" * 1000, max_chars=50)
        assert len(clipped) <= 70
        assert "truncated" in clipped
|
||||||
|
|
||||||
|
|
||||||
|
# ── Growth Tracking ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestGrowthTracking:
    """Average-growth and turns-remaining estimates from update history."""

    def _record(self, budget, *counts):
        # Feed a sequence of cumulative token counts into the budget.
        for count in counts:
            budget.update(count)

    def test_growth_rate(self, budget):
        self._record(budget, 10_000, 15_000, 20_000)
        assert budget.growth_rate() == 5_000.0

    def test_turns_remaining(self, budget):
        self._record(budget, 10_000, 15_000, 20_000)
        # Growth of 5000/turn against 108K remaining -> roughly 21 turns.
        remaining = budget.turns_remaining()
        assert remaining is not None
        assert 18 <= remaining <= 24

    def test_no_history(self, budget):
        assert budget.growth_rate() is None
        assert budget.turns_remaining() is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Status Indicators ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestStatusIndicators:
    """Human-readable indicator, bar, and summary renderings."""

    def test_indicator_normal(self, budget):
        budget.update(int(128_000 * 0.50))
        assert "50" in budget.check().to_indicator()

    def test_indicator_warning(self, budget):
        budget.update(int(128_000 * 0.65))
        rendered = budget.check().to_indicator()
        # Either the warning glyph or the percentage must show up.
        assert "\u26a0" in rendered or "65" in rendered

    def test_bar(self, budget):
        budget.update(int(128_000 * 0.50))
        assert "50" in budget.check().to_bar()

    def test_summary(self, budget):
        budget.update(50_000)
        text = budget.summary()
        for needle in ("50,000", "128,000", "NORMAL"):
            assert needle in text
|
||||||
|
|
||||||
|
|
||||||
|
# ── Reset ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestReset:
    """reset() returns the budget to a pristine NORMAL state."""

    def test_reset_clears_state(self, budget):
        budget.update(int(128_000 * 0.90))
        budget.reset()
        assert budget.tokens_used == 0
        assert budget.check().level == BudgetLevel.NORMAL
        assert budget.growth_rate() is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Edge Cases ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestEdgeCases:
    """Boundary and degenerate-input behaviour."""

    def test_exact_threshold_boundary(self, budget):
        # Landing exactly on 60% already counts as WARNING.
        budget.update(int(128_000 * 0.60))
        assert budget.check().level == BudgetLevel.WARNING

    def test_zero_context(self):
        empty = TokenBudget(context_length=0)
        assert empty.check().percent_used == 0

    def test_remaining_for_response(self, budget):
        budget.update(100_000)
        headroom = budget.remaining_for_response()
        # 128000 - 100000 - 6400 (5% reserve) = 21600, so strictly between.
        assert 0 < headroom < 128_000
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running this file directly: `python tests/test_token_budget.py`.
    pytest.main([__file__, "-v"])
|
||||||
67
tests/test_tool_validator.py
Normal file
67
tests/test_tool_validator.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
Tests for tool hallucination detection (#922).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from tools.tool_validator import ToolHallucinationDetector, ValidationSeverity
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolHallucinationDetector:
    """Validation of LLM tool calls against a registered schema."""

    def setup_method(self):
        # Fresh detector per test with one known tool: read_file(path, encoding?).
        self.detector = ToolHallucinationDetector()
        schema = {
            "parameters": {
                "type": "object",
                "properties": {
                    "path": {"type": "string"},
                    "encoding": {"type": "string"},
                },
                "required": ["path"],
            }
        }
        self.detector.register_tool("read_file", schema)

    def _codes(self, result):
        # Collect the issue codes of a validation result.
        return {issue.code for issue in result.issues}

    def test_valid_tool_call(self):
        result = self.detector.validate_tool_call("read_file", {"path": "/tmp/file.txt"})
        assert result.valid is True
        assert not result.blocking_issues

    def test_unknown_tool(self):
        result = self.detector.validate_tool_call("hallucinated_tool", {})
        assert result.valid is False
        assert "UNKNOWN_TOOL" in self._codes(result)

    def test_missing_required_param(self):
        result = self.detector.validate_tool_call("read_file", {})
        assert result.valid is False
        assert "MISSING_REQUIRED" in self._codes(result)

    def test_wrong_type(self):
        result = self.detector.validate_tool_call("read_file", {"path": 123})
        assert result.valid is False
        assert "WRONG_TYPE" in self._codes(result)

    def test_unknown_param_warning(self):
        result = self.detector.validate_tool_call(
            "read_file", {"path": "/tmp/file.txt", "unknown": "value"}
        )
        assert result.valid is True  # Warning, not blocking
        assert "UNKNOWN_PARAM" in self._codes(result)

    def test_placeholder_detection(self):
        result = self.detector.validate_tool_call("read_file", {"path": "<placeholder>"})
        assert "PLACEHOLDER_VALUE" in self._codes(result)

    def test_rejection_stats(self):
        self.detector.validate_tool_call("unknown_tool", {})
        self.detector.validate_tool_call("read_file", {})
        assert self.detector.get_rejection_stats()["total"] >= 2

    def test_rejection_response(self):
        from tools.tool_validator import create_rejection_response

        result = self.detector.validate_tool_call("unknown_tool", {})
        response = create_rejection_response(result)
        assert response["role"] == "tool"
        assert "rejected" in response["content"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Allow running this file directly: `python tests/test_tool_validator.py`.
    pytest.main([__file__])
|
||||||
@@ -28,6 +28,7 @@ Platform: Linux / macOS only (Unix domain sockets for local). Disabled on Window
|
|||||||
Remote execution additionally requires Python 3 in the terminal backend.
|
Remote execution additionally requires Python 3 in the terminal backend.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
@@ -883,6 +884,42 @@ def _execute_remote(
|
|||||||
return json.dumps(result, ensure_ascii=False)
|
return json.dumps(result, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_python_syntax(code: str) -> Optional[str]:
|
||||||
|
"""Validate Python syntax before subprocess spawn.
|
||||||
|
|
||||||
|
Runs ast.parse() in-process (sub-millisecond) to catch syntax errors
|
||||||
|
before wasting time spawning a sandboxed subprocess.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
JSON error string with line, offset, message if syntax is invalid.
|
||||||
|
None if syntax is valid.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
ast.parse(code)
|
||||||
|
return None
|
||||||
|
except SyntaxError as exc:
|
||||||
|
# Build context: show offending line with caret
|
||||||
|
lines = code.split("\n")
|
||||||
|
error_line = lines[exc.lineno - 1] if exc.lineno and exc.lineno <= len(lines) else ""
|
||||||
|
context = ""
|
||||||
|
if error_line:
|
||||||
|
context = f"\n {error_line}"
|
||||||
|
if exc.offset:
|
||||||
|
context += f"\n {' ' * (exc.offset - 1)}^"
|
||||||
|
|
||||||
|
return json.dumps({
|
||||||
|
"error": f"Python syntax error on line {exc.lineno}: {exc.msg}{context}",
|
||||||
|
"syntax_error": True,
|
||||||
|
"line": exc.lineno,
|
||||||
|
"offset": exc.offset,
|
||||||
|
"message": exc.msg,
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Main entry point
|
# Main entry point
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -916,6 +953,11 @@ def execute_code(
|
|||||||
if not code or not code.strip():
|
if not code or not code.strip():
|
||||||
return tool_error("No code provided.")
|
return tool_error("No code provided.")
|
||||||
|
|
||||||
|
# Syntax check before subprocess spawn (catches ~15% of errors in <1ms)
|
||||||
|
syntax_error = _validate_python_syntax(code)
|
||||||
|
if syntax_error:
|
||||||
|
return syntax_error
|
||||||
|
|
||||||
# Dispatch: remote backends use file-based RPC, local uses UDS
|
# Dispatch: remote backends use file-based RPC, local uses UDS
|
||||||
from tools.terminal_tool import _get_env_config
|
from tools.terminal_tool import _get_env_config
|
||||||
env_type = _get_env_config()["env_type"]
|
env_type = _get_env_config()["env_type"]
|
||||||
|
|||||||
183
tools/credential_redact.py
Normal file
183
tools/credential_redact.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""
|
||||||
|
Credential Redaction — Block silent credential exposure in tool outputs
|
||||||
|
|
||||||
|
Poka-yoke: Prevent API keys, tokens, passwords from leaking into context.
|
||||||
|
|
||||||
|
Issue: #839
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)

# Hermes home directory; redaction events are appended under audit/.
HERMES_HOME = Path.home() / ".hermes"
AUDIT_DIR = HERMES_HOME / "audit"

# Credential patterns to detect and redact.
# Each entry is a (regex, replacement) pair. NOTE: the redactor applies
# re.IGNORECASE, so patterns do not need per-case duplicates (the original
# carried a redundant lowercase "bearer" duplicate, removed here). Quote
# characters inside character classes are escaped so the raw-string
# literals are valid Python (the previous literals were malformed).
CREDENTIAL_PATTERNS = [
    # Vendor API keys
    (r"sk-[a-zA-Z0-9]{20,}", "[REDACTED: OpenAI API key]"),
    (r"sk-ant-[a-zA-Z0-9-]{20,}", "[REDACTED: Anthropic API key]"),
    (r"ghp_[a-zA-Z0-9]{36}", "[REDACTED: GitHub token]"),
    (r"gho_[a-zA-Z0-9]{36}", "[REDACTED: GitHub OAuth token]"),
    (r"glpat-[a-zA-Z0-9-]{20,}", "[REDACTED: GitLab token]"),

    # Bearer tokens (case-insensitive matching covers "bearer" too)
    (r"Bearer\s+[a-zA-Z0-9._-]{20,}", "[REDACTED: Bearer token]"),

    # Generic key=value / key: value style tokens, passwords, secrets
    (r"(?:token|TOKEN|Token)[:=]\s*[\"']?[a-zA-Z0-9._-]{20,}[\"']?", "[REDACTED: Token]"),
    (r"(?:password|PASSWORD|Password)[:=]\s*[\"']?[^\s\"']{8,}[\"']?", "[REDACTED: Password]"),
    (r"(?:secret|SECRET|Secret)[:=]\s*[\"']?[a-zA-Z0-9._-]{20,}[\"']?", "[REDACTED: Secret]"),
    (r"(?:api_key|API_KEY|apiKey|ApiKey)[:=]\s*[\"']?[a-zA-Z0-9._-]{20,}[\"']?", "[REDACTED: API key]"),

    # AWS keys
    (r"AKIA[0-9A-Z]{16}", "[REDACTED: AWS access key]"),
    (r"(?:aws_secret_access_key|AWS_SECRET_ACCESS_KEY)[:=]\s*[\"']?[a-zA-Z0-9/+=]{40}[\"']?", "[REDACTED: AWS secret]"),

    # Private key material
    (r"-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----", "[REDACTED: Private key header]"),

    # Database connection strings embedding user:password
    (r"(?:postgres|mysql|mongodb|redis)://[^:]+:[^@]+@[^\s]+", "[REDACTED: Database connection string]"),
]

# File-name patterns that should trigger auto-masking of file contents.
SENSITIVE_FILE_PATTERNS = [
    r"\.env$",
    r"\.env\.",
    r"\.secret",
    r"credentials",
    r"\.token",
    r"config\.yaml$",
    r"config\.yml$",
    r"config\.json$",
    r"\.netrc$",
    r"\.pgpass$",
]
|
||||||
|
|
||||||
|
|
||||||
|
class CredentialRedactor:
    """Redact credential-looking substrings from text.

    Matching is driven by the module-level CREDENTIAL_PATTERNS list; every
    redaction event can optionally be appended to an audit log under
    AUDIT_DIR (best-effort — audit failures never break redaction).
    """

    def __init__(self, audit_log: bool = True):
        # When True, each redaction event is appended to audit/redactions.jsonl.
        self.audit_log = audit_log
        self._redaction_count = 0

    def redact(self, text: str) -> Tuple[str, int]:
        """
        Redact credentials from text.

        Returns:
            Tuple of (redacted_text, number_of_redactions)
        """
        if not text:
            return text, 0

        redacted = text
        count = 0

        for pattern, replacement in CREDENTIAL_PATTERNS:
            # re.subn replaces and counts in a single pass; the original
            # findall-then-sub scanned the text twice per pattern.
            redacted, hits = re.subn(pattern, replacement, redacted, flags=re.IGNORECASE)
            count += hits

        if count > 0:
            self._redaction_count += count
            if self.audit_log:
                self._log_redaction(count, text[:100])

        return redacted, count

    def redact_tool_output(self, tool_name: str, output: str) -> Tuple[str, str]:
        """
        Redact tool output and return notice if redactions occurred.

        Returns:
            Tuple of (redacted_output, notice_or_empty)
        """
        redacted, count = self.redact(output)

        if count > 0:
            notice = f"[REDACTED: {count} credential pattern{'s' if count > 1 else ''} found in {tool_name} output]"
            return redacted, notice

        return redacted, ""

    def should_mask_file(self, file_path: str) -> bool:
        """Check if file should have credentials auto-masked."""
        path_lower = file_path.lower()
        return any(re.search(p, path_lower) for p in SENSITIVE_FILE_PATTERNS)

    def mask_file_content(self, content: str, file_path: str) -> str:
        """Mask credentials in file content while preserving structure."""
        if not self.should_mask_file(file_path):
            return content

        masked_lines = []
        # Hoisted out of the loop; substring match on the key name decides
        # whether the value is blanked.
        sensitive_keys = ("password", "secret", "token", "key", "api", "credential")

        for line in content.split("\n"):
            # Preserve key=value structure but mask values of sensitive keys.
            if "=" in line and not line.strip().startswith("#"):
                key = line.partition("=")[0]
                if any(sk in key.strip().lower() for sk in sensitive_keys):
                    masked_lines.append(f"{key}=[REDACTED]")
                else:
                    masked_lines.append(line)
            else:
                masked_lines.append(line)

        return "\n".join(masked_lines)

    def _log_redaction(self, count: int, preview: str):
        """Log redaction event to audit trail (best-effort)."""
        try:
            import hashlib  # local import: only needed on the audit path

            AUDIT_DIR.mkdir(parents=True, exist_ok=True)
            audit_file = AUDIT_DIR / "redactions.jsonl"

            entry = {
                "timestamp": datetime.now(timezone.utc).isoformat(),
                "redactions": count,
                # sha256 is stable across processes, unlike builtin hash()
                # which is randomised per interpreter run (PYTHONHASHSEED),
                # making audit entries for identical previews correlatable.
                "preview_hash": hashlib.sha256(preview.encode("utf-8")).hexdigest(),
            }

            with open(audit_file, "a") as f:
                f.write(json.dumps(entry) + "\n")

        except Exception as e:
            # Auditing must never break redaction itself.
            logger.debug("Audit log failed: %s", e)
|
||||||
|
|
||||||
|
|
||||||
|
# Shared module-level redactor backing the convenience functions below.
_redactor = CredentialRedactor()


def redact_credentials(text: str) -> Tuple[str, int]:
    """Redact credentials from *text*; returns (redacted_text, count)."""
    return _redactor.redact(text)


def redact_tool_output(tool_name: str, output: str) -> Tuple[str, str]:
    """Redact *output* from *tool_name*; returns (redacted_output, notice)."""
    return _redactor.redact_tool_output(tool_name, output)


def should_mask_file(file_path: str) -> bool:
    """Return True when *file_path* looks like a sensitive file."""
    return _redactor.should_mask_file(file_path)


def mask_sensitive_file(content: str, file_path: str) -> str:
    """Mask credential values inside a sensitive file's *content*."""
    return _redactor.mask_file_content(content, file_path)
|
||||||
@@ -327,6 +327,33 @@ def read_file_tool(path: str, offset: int = 1, limit: int = 500, task_id: str =
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ── Path existence guard (poka-yoke #887) ─────────────────────
|
||||||
|
# Check if file exists before attempting read. 83.7% of read_file
|
||||||
|
# errors are file-not-found — the agent hallucinates paths.
|
||||||
|
# This guard catches them early with a clear, actionable error.
|
||||||
|
if not _resolved.exists():
|
||||||
|
# Try to suggest similar files in the same directory
|
||||||
|
parent = _resolved.parent
|
||||||
|
suggestion = ""
|
||||||
|
if parent.exists() and parent.is_dir():
|
||||||
|
similar = [
|
||||||
|
f.name for f in parent.iterdir()
|
||||||
|
if f.is_file() and _resolved.stem[:3].lower() in f.stem.lower()
|
||||||
|
][:5]
|
||||||
|
if similar:
|
||||||
|
suggestion = f" Similar files in {parent}: {', '.join(similar)}"
|
||||||
|
return json.dumps({
|
||||||
|
"error": (
|
||||||
|
f"File not found: '{path}'. The file does not exist at the resolved path "
|
||||||
|
f"({_resolved}).{suggestion} "
|
||||||
|
"Use search_files to find the correct path first."
|
||||||
|
),
|
||||||
|
"path": path,
|
||||||
|
"resolved": str(_resolved),
|
||||||
|
"suggestion": "Use search_files(pattern='...', target='files') to find files.",
|
||||||
|
})
|
||||||
|
|
||||||
# ── Dedup check ───────────────────────────────────────────────
|
# ── Dedup check ───────────────────────────────────────────────
|
||||||
# If we already read this exact (path, offset, limit) and the
|
# If we already read this exact (path, offset, limit) and the
|
||||||
# file hasn't been modified since, return a lightweight stub
|
# file hasn't been modified since, return a lightweight stub
|
||||||
|
|||||||
298
tools/poka_yoke.py
Normal file
298
tools/poka_yoke.py
Normal file
@@ -0,0 +1,298 @@
|
|||||||
|
"""
|
||||||
|
poka_yoke.py — Validation firewall for tool calls.
|
||||||
|
|
||||||
|
Poka-yoke (mistake-proofing): validates tool calls against the registry
|
||||||
|
before execution. Catches hallucinated tool names, malformed parameters,
|
||||||
|
missing required arguments, and type mismatches.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from tools.poka_yoke import validate_tool_call
|
||||||
|
|
||||||
|
is_valid, corrected_name, corrected_params, messages = validate_tool_call(
|
||||||
|
"read_file", {"path": "test.txt"}
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_tool_call(
    function_name: str,
    function_args: Dict[str, Any],
) -> Tuple[bool, Optional[str], Optional[Dict[str, Any]], List[str]]:
    """Validate a tool call against the registry before execution.

    Args:
        function_name: The tool name from the LLM's function_call.
        function_args: The arguments dict from the LLM's function_call.

    Returns:
        Tuple of (is_valid, corrected_name, corrected_params, messages):
        - is_valid: False if the call should be blocked entirely.
        - corrected_name: Suggested name if a close match was found (None if OK).
        - corrected_params: Corrected params if cleanup/coercion fixed issues
          (None if OK).
        - messages: List of error/warning/info messages.
    """
    from tools.registry import registry

    messages: List[str] = []
    corrected_name: Optional[str] = None
    corrected_params: Optional[Dict[str, Any]] = None

    # ── 1. Check if tool exists ───────────────────────────────────────────

    entry = registry.get_entry(function_name)

    if entry is None:
        # Tool not found — suggest closest match
        all_names = registry.get_all_tool_names()
        suggestion = _find_closest_name(function_name, all_names)

        if suggestion:
            messages.append(
                f"Unknown tool '{function_name}'. Did you mean '{suggestion}'?"
            )
            corrected_name = suggestion
            # Re-validate with corrected name
            entry = registry.get_entry(suggestion)
            if entry is None:
                return False, corrected_name, None, messages
        else:
            available = ", ".join(sorted(all_names)[:20])
            messages.append(
                f"Unknown tool '{function_name}'. "
                f"Available tools: {available}{'...' if len(all_names) > 20 else ''}"
            )
            return False, None, None, messages

    # ── 2. Validate parameters against schema ─────────────────────────────

    schema = entry.schema
    params_schema = schema.get("parameters", {})
    properties = params_schema.get("properties", {})
    required = set(params_schema.get("required", []))

    # Check for missing required parameters
    for param_name in sorted(required):
        if param_name not in function_args:
            param_info = properties.get(param_name, {})
            param_type = param_info.get("type", "any")
            messages.append(
                f"Missing required parameter '{param_name}' "
                f"(expected type: {param_type}). "
                f"Tool: {function_name}"
            )

    # Missing required params block the call, but the errors are returned
    # in messages so the agent can read them and retry a corrected call.
    if any("Missing required" in m for m in messages):
        return False, corrected_name, corrected_params, messages

    # ── 3. Check for unknown parameters ───────────────────────────────────

    if properties:
        known_params = set(properties.keys())
        # Allow extra params that start with _ (internal convention)
        unknown = [
            p for p in function_args
            if p not in known_params and not p.startswith("_")
        ]
        if unknown:
            known_str = ", ".join(sorted(known_params))
            unknown_str = ", ".join(sorted(unknown))
            messages.append(
                f"Unknown parameter(s) for '{function_name}': {unknown_str}. "
                f"Known parameters: {known_str}"
            )
            # Remove unknown params (don't block, just clean)
            corrected_params = {
                k: v for k, v in function_args.items()
                if k in known_params or k.startswith("_")
            }

    # ── 4. Type validation ────────────────────────────────────────────────

    type_errors = []
    coerced = dict(corrected_params or function_args)

    # Iterate a snapshot so in-place coercion writes are clearly safe.
    for param_name, param_value in list(coerced.items()):
        if param_name.startswith("_"):
            continue
        param_schema = properties.get(param_name)
        if not param_schema:
            continue

        expected_type = param_schema.get("type")
        if not expected_type:
            continue

        is_valid_type, coerced_value = _validate_type(
            param_name, param_value, expected_type
        )
        if not is_valid_type:
            type_errors.append(
                f"Parameter '{param_name}': expected {expected_type}, "
                f"got {type(param_value).__name__} ({_truncate(str(param_value), 50)})"
            )
        elif coerced_value is not param_value:
            coerced[param_name] = coerced_value
            messages.append(
                f"Parameter '{param_name}': coerced from "
                f"{type(param_value).__name__} to {expected_type}"
            )

    if type_errors:
        messages.extend(type_errors)
        return False, corrected_name, corrected_params, messages

    if coerced != (corrected_params or function_args):
        corrected_params = coerced

    # ── 5. Enum validation ────────────────────────────────────────────────

    for param_name, param_value in (corrected_params or function_args).items():
        param_schema = properties.get(param_name, {})
        enum_values = param_schema.get("enum")
        if enum_values and param_value not in enum_values:
            messages.append(
                f"Parameter '{param_name}': value '{param_value}' not in "
                f"allowed values: {enum_values}"
            )
            return False, corrected_name, corrected_params, messages

    # ── 6. Pattern validation (warning only — does not block) ─────────────

    for param_name, param_value in (corrected_params or function_args).items():
        if not isinstance(param_value, str):
            continue
        param_schema = properties.get(param_name, {})
        pattern = param_schema.get("pattern")
        if pattern and not re.match(pattern, param_value):
            messages.append(
                f"Parameter '{param_name}': value '{_truncate(param_value, 50)}' "
                f"does not match pattern '{pattern}'"
            )

    # ── Done ──────────────────────────────────────────────────────────────
    # Every blocking condition (missing required, bad type, bad enum value)
    # returned early above, so the call is valid here. The original re-scan
    # of messages for "Missing required" was dead code: that condition can
    # never be true at this point.
    if not messages:
        return True, None, None, []

    return True, corrected_name, corrected_params, messages
|
||||||
|
|
||||||
|
|
||||||
|
def _find_closest_name(target: str, candidates: List[str]) -> Optional[str]:
|
||||||
|
"""Find the closest tool name using simple edit distance heuristics."""
|
||||||
|
if not candidates:
|
||||||
|
return None
|
||||||
|
|
||||||
|
target_lower = target.lower()
|
||||||
|
|
||||||
|
# Exact prefix match
|
||||||
|
for name in candidates:
|
||||||
|
if name.lower().startswith(target_lower[:4]) and len(target_lower) > 3:
|
||||||
|
return name
|
||||||
|
|
||||||
|
# Substring match
|
||||||
|
for name in candidates:
|
||||||
|
if target_lower in name.lower() or name.lower() in target_lower:
|
||||||
|
return name
|
||||||
|
|
||||||
|
# Levenshtein distance (simple, for short strings)
|
||||||
|
def _levenshtein(a: str, b: str) -> int:
|
||||||
|
if len(a) < len(b):
|
||||||
|
return _levenshtein(b, a)
|
||||||
|
if len(b) == 0:
|
||||||
|
return len(a)
|
||||||
|
prev = range(len(b) + 1)
|
||||||
|
for i, ca in enumerate(a):
|
||||||
|
curr = [i + 1]
|
||||||
|
for j, cb in enumerate(b):
|
||||||
|
curr.append(min(
|
||||||
|
prev[j + 1] + 1,
|
||||||
|
curr[j] + 1,
|
||||||
|
prev[j] + (0 if ca == cb else 1),
|
||||||
|
))
|
||||||
|
prev = curr
|
||||||
|
return prev[-1]
|
||||||
|
|
||||||
|
distances = [(name, _levenshtein(target_lower, name.lower())) for name in candidates]
|
||||||
|
distances.sort(key=lambda x: x[1])
|
||||||
|
|
||||||
|
# Return if edit distance is small enough
|
||||||
|
if distances and distances[0][1] <= max(3, len(target) // 3):
|
||||||
|
return distances[0][0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_type(
|
||||||
|
param_name: str, value: Any, expected_type: str
|
||||||
|
) -> Tuple[bool, Any]:
|
||||||
|
"""Validate and optionally coerce a parameter value to the expected type.
|
||||||
|
|
||||||
|
Returns (is_valid, coerced_value). coerced_value is value itself if no
|
||||||
|
coercion was needed.
|
||||||
|
"""
|
||||||
|
type_map = {
|
||||||
|
"string": str,
|
||||||
|
"integer": int,
|
||||||
|
"number": (int, float),
|
||||||
|
"boolean": bool,
|
||||||
|
"array": list,
|
||||||
|
"object": dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = type_map.get(expected_type)
|
||||||
|
if expected is None:
|
||||||
|
return True, value # Unknown type, skip validation
|
||||||
|
|
||||||
|
# Direct type check
|
||||||
|
if isinstance(value, expected):
|
||||||
|
return True, value
|
||||||
|
|
||||||
|
# Coercion attempts
|
||||||
|
if expected_type == "string":
|
||||||
|
return True, str(value)
|
||||||
|
|
||||||
|
if expected_type == "integer":
|
||||||
|
if isinstance(value, str) and value.isdigit():
|
||||||
|
return True, int(value)
|
||||||
|
if isinstance(value, float) and value == int(value):
|
||||||
|
return True, int(value)
|
||||||
|
return False, value
|
||||||
|
|
||||||
|
if expected_type == "number":
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return True, float(value)
|
||||||
|
except ValueError:
|
||||||
|
return False, value
|
||||||
|
return False, value
|
||||||
|
|
||||||
|
if expected_type == "boolean":
|
||||||
|
if isinstance(value, str):
|
||||||
|
lower = value.lower()
|
||||||
|
if lower in ("true", "1", "yes"):
|
||||||
|
return True, True
|
||||||
|
if lower in ("false", "0", "no"):
|
||||||
|
return True, False
|
||||||
|
return False, value
|
||||||
|
|
||||||
|
return False, value
|
||||||
|
|
||||||
|
|
||||||
|
def _truncate(s: str, max_len: int) -> str:
|
||||||
|
"""Truncate a string for display."""
|
||||||
|
if len(s) <= max_len:
|
||||||
|
return s
|
||||||
|
return s[:max_len - 3] + "..."
|
||||||
275
tools/session_templates.py
Normal file
275
tools/session_templates.py
Normal file
@@ -0,0 +1,275 @@
|
|||||||
|
"""
|
||||||
|
Session templates for code-first seeding.
|
||||||
|
|
||||||
|
Research: Code-heavy sessions (execute_code dominant in first 30 turns) improve over time.
|
||||||
|
File-heavy sessions degrade. Key is deterministic feedback loops.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Any
|
||||||
|
from dataclasses import dataclass, asdict, field
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TEMPLATE_DIR = Path.home() / ".hermes" / "session-templates"
|
||||||
|
|
||||||
|
|
||||||
|
class TaskType(Enum):
    """Dominant tool category of a session, used to match seeding templates."""

    CODE = "code"  # execute_code-dominant sessions
    FILE = "file"  # file read/write/patch-dominant sessions
    RESEARCH = "research"  # web search/fetch-dominant sessions
    MIXED = "mixed"  # no single category dominates
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolExample:
    """A single recorded tool call (and its result) from a past session."""

    tool_name: str  # name of the tool that was invoked
    arguments: Dict[str, Any]  # arguments passed to the tool
    result: str  # tool output; may be filled in after creation (see extract())
    success: bool  # whether the call succeeded
    turn: int = 0  # conversation turn index of the call

    def to_dict(self):
        """Serialize to a plain JSON-friendly dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, data):
        """Reconstruct an instance from a dict produced by to_dict()."""
        return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Template:
    """A reusable session-seeding template extracted from a past session."""

    name: str  # unique template name; also the JSON filename stem
    task_type: TaskType  # dominant tool category of the source session
    examples: List[ToolExample]  # ordered tool-call examples to inject
    desc: str = ""  # human-readable description
    created: float = 0.0  # creation time (unix); defaulted in __post_init__
    used: int = 0  # number of times this template has been injected
    session_id: Optional[str] = None  # source session id, if known
    tags: List[str] = field(default_factory=list)  # free-form filter tags

    def __post_init__(self):
        # Default the creation timestamp to "now" when not supplied.
        if self.created == 0.0:
            self.created = time.time()

    def to_dict(self):
        """Serialize to a JSON-friendly dict (enum stored by its value)."""
        d = asdict(self)
        d['task_type'] = self.task_type.value
        return d

    @classmethod
    def from_dict(cls, data):
        """Reconstruct an instance from a dict produced by to_dict().

        Operates on a shallow copy so the caller's mapping is not mutated
        (the previous version rewrote 'task_type'/'examples' in place).
        """
        data = dict(data)
        data['task_type'] = TaskType(data['task_type'])
        data['examples'] = [ToolExample.from_dict(e) for e in data.get('examples', [])]
        return cls(**data)
|
||||||
|
|
||||||
|
|
||||||
|
class Templates:
    """Load, create, inject and manage session seeding templates.

    Templates are persisted as one JSON file per template under *dir*
    (default: ~/.hermes/session-templates).
    """

    def __init__(self, dir=None):
        """Open the store at *dir* (created if missing) and load templates."""
        self.dir = dir or TEMPLATE_DIR
        self.dir.mkdir(parents=True, exist_ok=True)
        self.templates = {}  # name -> Template
        self._load()

    def _load(self):
        """Load every *.json template in self.dir; skip unparsable files."""
        for f in self.dir.glob("*.json"):
            try:
                with open(f) as fh:
                    t = Template.from_dict(json.load(fh))
                self.templates[t.name] = t
            except Exception as e:
                logger.warning(f"Load failed {f}: {e}")

    def _save(self, t):
        """Persist one template as pretty-printed JSON."""
        with open(self.dir / f"{t.name}.json", 'w') as f:
            json.dump(t.to_dict(), f, indent=2)

    def classify(self, calls):
        """Classify a list of {'tool_name': ...} dicts by dominant tool family.

        A family "dominates" when it accounts for more than 60% of calls;
        otherwise the session is MIXED.
        """
        if not calls:
            return TaskType.MIXED
        code = {'execute_code', 'code_execution'}
        file_ops = {'read_file', 'write_file', 'patch', 'search_files'}
        research = {'web_search', 'web_fetch', 'browser_navigate'}
        names = [c.get('tool_name', '') for c in calls]
        total = len(names)
        if sum(1 for n in names if n in code) / total > 0.6:
            return TaskType.CODE
        if sum(1 for n in names if n in file_ops) / total > 0.6:
            return TaskType.FILE
        if sum(1 for n in names if n in research) / total > 0.6:
            return TaskType.RESEARCH
        return TaskType.MIXED

    def extract(self, session_id, max_n=10):
        """Extract up to *max_n* ToolExamples from a session in the state DB.

        Returns [] when the DB is missing or anything goes wrong.
        """
        db = Path.home() / ".hermes" / "state.db"
        if not db.exists():
            return []
        try:
            conn = sqlite3.connect(str(db))
            try:
                conn.row_factory = sqlite3.Row
                rows = conn.execute(
                    "SELECT role, content, tool_calls FROM messages WHERE session_id=? ORDER BY timestamp LIMIT 100",
                    (session_id,)
                ).fetchall()
            finally:
                # Close even if the query raises (connection was leaked before).
                conn.close()
            examples = []
            turn = 0
            for r in rows:
                if len(examples) >= max_n:
                    break
                if r['role'] == 'assistant' and r['tool_calls']:
                    try:
                        for tc in json.loads(r['tool_calls']):
                            if len(examples) >= max_n:
                                break
                            name = tc.get('function', {}).get('name')
                            if not name:
                                continue
                            try:
                                args = json.loads(tc.get('function', {}).get('arguments', '{}'))
                            except (json.JSONDecodeError, TypeError):
                                # Malformed arguments payload — fall back to empty.
                                # (Was a bare `except:`, which also swallowed
                                # KeyboardInterrupt/SystemExit.)
                                args = {}
                            examples.append(ToolExample(name, args, "", True, turn))
                            turn += 1
                    except Exception:
                        # Unparsable tool_calls blob — skip this message.
                        # (Was a bare `except:`.)
                        continue
                elif r['role'] == 'tool' and examples and examples[-1].result == "":
                    # Attach the first following tool result to the last example.
                    examples[-1].result = r['content'] or ""
            return examples
        except Exception as e:
            logger.error(f"Extract failed: {e}")
            return []

    def create(self, session_id, name=None, task_type=None, max_n=10, desc="", tags=None):
        """Build and persist a template from a session; returns it or None."""
        examples = self.extract(session_id, max_n)
        if not examples:
            return None
        if task_type is None:
            task_type = self.classify([{'tool_name': e.tool_name} for e in examples])
        if name is None:
            name = f"{task_type.value}_{session_id[:8]}_{int(time.time())}"
        t = Template(name, task_type, examples, desc or f"{len(examples)} examples", time.time(), 0, session_id, tags or [])
        self.templates[name] = t
        self._save(t)
        logger.info(f"Created {name} with {len(examples)} examples")
        return t

    def get(self, task_type, tags=None):
        """Return the least-used template matching *task_type* (and any tag), or None."""
        matching = [t for t in self.templates.values() if t.task_type == task_type]
        if tags:
            matching = [t for t in matching if any(tag in t.tags for tag in tags)]
        if not matching:
            return None
        matching.sort(key=lambda t: t.used)
        return matching[0]

    def inject(self, template, messages):
        """Insert template examples right after the leading system messages.

        Mutates and returns *messages*; bumps and persists the usage count.
        """
        if not template.examples:
            return messages
        injection = [{
            "role": "system",
            "content": f"Template: {template.name} ({template.task_type.value})\n{template.desc}"
        }]
        for i, ex in enumerate(template.examples):
            injection.append({
                "role": "assistant",
                "content": None,
                "tool_calls": [{
                    "id": f"tpl_{i}",
                    "type": "function",
                    "function": {"name": ex.tool_name, "arguments": json.dumps(ex.arguments)}
                }]
            })
            injection.append({
                "role": "tool",
                "tool_call_id": f"tpl_{i}",
                "content": ex.result
            })
        # Find the insertion point: just past the run of leading system messages.
        idx = 0
        for i, m in enumerate(messages):
            if m.get("role") != "system":
                break
            idx = i + 1
        for i, m in enumerate(injection):
            messages.insert(idx + i, m)
        template.used += 1
        self._save(template)
        return messages

    def list(self, task_type=None, tags=None):
        """List templates, newest first, optionally filtered by type and tags."""
        ts = list(self.templates.values())
        if task_type:
            ts = [t for t in ts if t.task_type == task_type]
        if tags:
            ts = [t for t in ts if any(tag in t.tags for tag in tags)]
        ts.sort(key=lambda t: t.created, reverse=True)
        return ts

    def delete(self, name):
        """Remove a template from memory and disk; True if it existed."""
        if name not in self.templates:
            return False
        del self.templates[name]
        p = self.dir / f"{name}.json"
        if p.exists():
            p.unlink()
        return True

    def stats(self):
        """Aggregate counts: total templates, per-type counts, examples, usage."""
        if not self.templates:
            return {"total": 0, "by_type": {}, "examples": 0, "usage": 0}
        by_type = {}
        total_ex = 0
        total_use = 0
        for t in self.templates.values():
            by_type[t.task_type.value] = by_type.get(t.task_type.value, 0) + 1
            total_ex += len(t.examples)
            total_use += t.used
        return {"total": len(self.templates), "by_type": by_type, "examples": total_ex, "usage": total_use}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Minimal CLI over the template store: list / create / delete / stats.
    import argparse
    p = argparse.ArgumentParser()
    s = p.add_subparsers(dest="cmd")
    lp = s.add_parser("list")
    lp.add_argument("--type", choices=["code", "file", "research", "mixed"])
    lp.add_argument("--tags")
    cp = s.add_parser("create")
    cp.add_argument("session_id")
    cp.add_argument("--name")
    cp.add_argument("--type", choices=["code", "file", "research", "mixed"])
    cp.add_argument("--max", type=int, default=10)
    cp.add_argument("--desc")
    cp.add_argument("--tags")
    dp = s.add_parser("delete")
    dp.add_argument("name")
    sp = s.add_parser("stats")
    args = p.parse_args()
    ts = Templates()
    if args.cmd == "list":
        # --type/--tags are optional filters; comma-separated tags.
        tt = TaskType(args.type) if args.type else None
        tags = args.tags.split(",") if args.tags else None
        for t in ts.list(tt, tags):
            print(f"{t.name}: {t.task_type.value} ({len(t.examples)} ex, used {t.used}x)")
    elif args.cmd == "create":
        tt = TaskType(args.type) if args.type else None
        tags = args.tags.split(",") if args.tags else None
        t = ts.create(args.session_id, args.name, tt, args.max, args.desc or "", tags)
        if t:
            print(f"Created: {t.name} ({len(t.examples)} examples)")
        else:
            print("Failed")
    elif args.cmd == "delete":
        print("Deleted" if ts.delete(args.name) else "Not found")
    elif args.cmd == "stats":
        # NOTE(review): rebinds the subparsers variable `s`; harmless because
        # the parser is no longer used past this point.
        s = ts.stats()
        print(f"Total: {s['total']}, Examples: {s['examples']}, Usage: {s['usage']}")
        for k, v in s['by_type'].items():
            print(f"  {k}: {v}")
    else:
        p.print_help()
|
||||||
@@ -38,12 +38,41 @@ import os
|
|||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from hermes_constants import get_hermes_home, display_hermes_home
|
from hermes_constants import get_hermes_home, display_hermes_home
|
||||||
from typing import Dict, Any, Optional, Tuple
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_error(
|
||||||
|
message: str,
|
||||||
|
skill_name: str = None,
|
||||||
|
file_path: str = None,
|
||||||
|
suggestion: str = None,
|
||||||
|
context: dict = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Format an error with rich context for better debugging."""
|
||||||
|
parts = [message]
|
||||||
|
if skill_name:
|
||||||
|
parts.append(f"Skill: {skill_name}")
|
||||||
|
if file_path:
|
||||||
|
parts.append(f"File: {file_path}")
|
||||||
|
if suggestion:
|
||||||
|
parts.append(f"Suggestion: {suggestion}")
|
||||||
|
if context:
|
||||||
|
for key, value in context.items():
|
||||||
|
parts.append(f"{key}: {value}")
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": " | ".join(parts),
|
||||||
|
"skill_name": skill_name,
|
||||||
|
"file_path": file_path,
|
||||||
|
"suggestion": suggestion,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# Import security scanner — agent-created skills get the same scrutiny as
|
# Import security scanner — agent-created skills get the same scrutiny as
|
||||||
# community hub installs.
|
# community hub installs.
|
||||||
try:
|
try:
|
||||||
@@ -253,6 +282,94 @@ def _resolve_skill_target(skill_dir: Path, file_path: str) -> Tuple[Optional[Pat
|
|||||||
return target, None
|
return target, None
|
||||||
|
|
||||||
|
|
||||||
|
MAX_BACKUPS_PER_FILE = 3
|
||||||
|
|
||||||
|
|
||||||
|
def _backup_skill_file(file_path: Path) -> Optional[Path]:
|
||||||
|
"""Create a timestamped backup of a skill file before modification.
|
||||||
|
|
||||||
|
The backup is named ``{original_name}.bak.{unix_timestamp}`` and placed
|
||||||
|
in the same directory. Returns the backup path, or *None* if the file
|
||||||
|
does not exist yet (nothing to back up).
|
||||||
|
"""
|
||||||
|
if not file_path.exists():
|
||||||
|
return None
|
||||||
|
timestamp = int(time.time())
|
||||||
|
backup_path = file_path.parent / f"{file_path.name}.bak.{timestamp}"
|
||||||
|
shutil.copy2(str(file_path), str(backup_path))
|
||||||
|
return backup_path
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_old_backups(file_path: Path, max_backups: int = MAX_BACKUPS_PER_FILE) -> None:
|
||||||
|
"""Prune backup files so at most *max_backups* are retained.
|
||||||
|
|
||||||
|
Backups match the pattern ``{file_path.name}.bak.*`` in the same
|
||||||
|
directory. The oldest (by mtime) are removed first.
|
||||||
|
"""
|
||||||
|
parent = file_path.parent
|
||||||
|
prefix = file_path.name + ".bak."
|
||||||
|
try:
|
||||||
|
backups: List[Path] = sorted(
|
||||||
|
[f for f in parent.iterdir() if f.name.startswith(prefix) and f.is_file()],
|
||||||
|
key=lambda p: p.stat().st_mtime,
|
||||||
|
)
|
||||||
|
except OSError:
|
||||||
|
return
|
||||||
|
while len(backups) > max_backups:
|
||||||
|
try:
|
||||||
|
backups.pop(0).unlink()
|
||||||
|
except OSError:
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_written_file(file_path: Path, is_skill_md: bool = False) -> Optional[str]:
|
||||||
|
"""Re-read a file from disk and validate it after writing.
|
||||||
|
|
||||||
|
Catches filesystem-level issues (truncation, encoding errors, empty
|
||||||
|
writes) that pre-write validation cannot detect. For SKILL.md files
|
||||||
|
the frontmatter is also re-validated.
|
||||||
|
|
||||||
|
Returns an error message, or *None* if the file looks healthy.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
content = file_path.read_text(encoding="utf-8")
|
||||||
|
except OSError as exc:
|
||||||
|
return f"Failed to read file after write: {exc}"
|
||||||
|
except UnicodeDecodeError as exc:
|
||||||
|
return f"File encoding error after write: {exc}"
|
||||||
|
|
||||||
|
if len(content) == 0:
|
||||||
|
return "File is empty after write (possible truncation)."
|
||||||
|
|
||||||
|
if is_skill_md:
|
||||||
|
err = _validate_frontmatter(content)
|
||||||
|
if err:
|
||||||
|
return f"Post-write validation failed: {err}"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _revert_from_backup(file_path: Path, backup_path: Optional[Path]) -> None:
|
||||||
|
"""Restore *file_path* from *backup_path*.
|
||||||
|
|
||||||
|
If *backup_path* is None or missing the target file is removed so the
|
||||||
|
skill directory is at least not left with corrupted content.
|
||||||
|
"""
|
||||||
|
if backup_path and backup_path.exists():
|
||||||
|
try:
|
||||||
|
shutil.copy2(str(backup_path), str(file_path))
|
||||||
|
except OSError:
|
||||||
|
logger.error(
|
||||||
|
"Failed to restore %s from backup %s", file_path, backup_path, exc_info=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No backup — remove the partially-written file
|
||||||
|
try:
|
||||||
|
file_path.unlink(missing_ok=True)
|
||||||
|
except OSError:
|
||||||
|
logger.error("Failed to remove corrupted file %s after failed write", file_path, exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
|
def _atomic_write_text(file_path: Path, content: str, encoding: str = "utf-8") -> None:
|
||||||
"""
|
"""
|
||||||
Atomically write text content to a file.
|
Atomically write text content to a file.
|
||||||
@@ -358,20 +475,35 @@ def _edit_skill(name: str, content: str) -> Dict[str, Any]:
|
|||||||
|
|
||||||
existing = _find_skill(name)
|
existing = _find_skill(name)
|
||||||
if not existing:
|
if not existing:
|
||||||
return {"success": False, "error": f"Skill '{name}' not found. Use skills_list() to see available skills."}
|
return _format_error(
|
||||||
|
f"Skill '{name}' not found.",
|
||||||
|
skill_name=name,
|
||||||
|
suggestion="Use skills_list() to see available skills.",
|
||||||
|
)
|
||||||
|
|
||||||
skill_md = existing["path"] / "SKILL.md"
|
skill_md = existing["path"] / "SKILL.md"
|
||||||
# Back up original content for rollback
|
|
||||||
original_content = skill_md.read_text(encoding="utf-8") if skill_md.exists() else None
|
# --- Transactional write-validate-commit-or-rollback ---
|
||||||
|
backup_path = _backup_skill_file(skill_md)
|
||||||
_atomic_write_text(skill_md, content)
|
_atomic_write_text(skill_md, content)
|
||||||
|
|
||||||
|
# Post-write validation: catch filesystem-level failures
|
||||||
|
validate_err = _validate_written_file(skill_md, is_skill_md=True)
|
||||||
|
if validate_err:
|
||||||
|
_revert_from_backup(skill_md, backup_path)
|
||||||
|
return {"success": False, "error": f"Edit reverted: {validate_err}"}
|
||||||
|
|
||||||
# Security scan — roll back on block
|
# Security scan — roll back on block
|
||||||
scan_error = _security_scan_skill(existing["path"])
|
scan_error = _security_scan_skill(existing["path"])
|
||||||
if scan_error:
|
if scan_error:
|
||||||
if original_content is not None:
|
_revert_from_backup(skill_md, backup_path)
|
||||||
_atomic_write_text(skill_md, original_content)
|
|
||||||
return {"success": False, "error": scan_error}
|
return {"success": False, "error": scan_error}
|
||||||
|
|
||||||
|
# Success — remove the backup we just created, prune any older ones
|
||||||
|
if backup_path:
|
||||||
|
backup_path.unlink(missing_ok=True)
|
||||||
|
_cleanup_old_backups(skill_md)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": f"Skill '{name}' updated.",
|
"message": f"Skill '{name}' updated.",
|
||||||
@@ -392,13 +524,25 @@ def _patch_skill(
|
|||||||
Requires a unique match unless replace_all is True.
|
Requires a unique match unless replace_all is True.
|
||||||
"""
|
"""
|
||||||
if not old_string:
|
if not old_string:
|
||||||
return {"success": False, "error": "old_string is required for 'patch'."}
|
return _format_error(
|
||||||
|
"old_string is required for 'patch'.",
|
||||||
|
skill_name=name,
|
||||||
|
suggestion="Provide the exact text to find in the skill file.",
|
||||||
|
)
|
||||||
if new_string is None:
|
if new_string is None:
|
||||||
return {"success": False, "error": "new_string is required for 'patch'. Use an empty string to delete matched text."}
|
return _format_error(
|
||||||
|
"new_string is required for 'patch'. Use an empty string to delete matched text.",
|
||||||
|
skill_name=name,
|
||||||
|
suggestion="Pass new_string='' to delete the matched text.",
|
||||||
|
)
|
||||||
|
|
||||||
existing = _find_skill(name)
|
existing = _find_skill(name)
|
||||||
if not existing:
|
if not existing:
|
||||||
return {"success": False, "error": f"Skill '{name}' not found."}
|
return _format_error(
|
||||||
|
f"Skill '{name}' not found.",
|
||||||
|
skill_name=name,
|
||||||
|
suggestion="Use skills_list() to see available skills.",
|
||||||
|
)
|
||||||
|
|
||||||
skill_dir = existing["path"]
|
skill_dir = existing["path"]
|
||||||
|
|
||||||
@@ -452,15 +596,29 @@ def _patch_skill(
|
|||||||
"error": f"Patch would break SKILL.md structure: {err}",
|
"error": f"Patch would break SKILL.md structure: {err}",
|
||||||
}
|
}
|
||||||
|
|
||||||
original_content = content # for rollback
|
is_skill_md = not file_path
|
||||||
|
|
||||||
|
# --- Transactional write-validate-commit-or-rollback ---
|
||||||
|
backup_path = _backup_skill_file(target)
|
||||||
_atomic_write_text(target, new_content)
|
_atomic_write_text(target, new_content)
|
||||||
|
|
||||||
|
# Post-write validation
|
||||||
|
validate_err = _validate_written_file(target, is_skill_md=is_skill_md)
|
||||||
|
if validate_err:
|
||||||
|
_revert_from_backup(target, backup_path)
|
||||||
|
return {"success": False, "error": f"Patch reverted: {validate_err}"}
|
||||||
|
|
||||||
# Security scan — roll back on block
|
# Security scan — roll back on block
|
||||||
scan_error = _security_scan_skill(skill_dir)
|
scan_error = _security_scan_skill(skill_dir)
|
||||||
if scan_error:
|
if scan_error:
|
||||||
_atomic_write_text(target, original_content)
|
_revert_from_backup(target, backup_path)
|
||||||
return {"success": False, "error": scan_error}
|
return {"success": False, "error": scan_error}
|
||||||
|
|
||||||
|
# Success — remove the backup we just created, prune any older ones
|
||||||
|
if backup_path:
|
||||||
|
backup_path.unlink(missing_ok=True)
|
||||||
|
_cleanup_old_backups(target)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": f"Patched {'SKILL.md' if not file_path else file_path} in skill '{name}' ({match_count} replacement{'s' if match_count > 1 else ''}).",
|
"message": f"Patched {'SKILL.md' if not file_path else file_path} in skill '{name}' ({match_count} replacement{'s' if match_count > 1 else ''}).",
|
||||||
@@ -519,19 +677,28 @@ def _write_file(name: str, file_path: str, file_content: str) -> Dict[str, Any]:
|
|||||||
if err:
|
if err:
|
||||||
return {"success": False, "error": err}
|
return {"success": False, "error": err}
|
||||||
target.parent.mkdir(parents=True, exist_ok=True)
|
target.parent.mkdir(parents=True, exist_ok=True)
|
||||||
# Back up for rollback
|
|
||||||
original_content = target.read_text(encoding="utf-8") if target.exists() else None
|
# --- Transactional write-validate-commit-or-rollback ---
|
||||||
|
backup_path = _backup_skill_file(target)
|
||||||
_atomic_write_text(target, file_content)
|
_atomic_write_text(target, file_content)
|
||||||
|
|
||||||
|
# Post-write validation: ensure the file is readable and non-empty
|
||||||
|
validate_err = _validate_written_file(target, is_skill_md=False)
|
||||||
|
if validate_err:
|
||||||
|
_revert_from_backup(target, backup_path)
|
||||||
|
return {"success": False, "error": f"Write reverted: {validate_err}"}
|
||||||
|
|
||||||
# Security scan — roll back on block
|
# Security scan — roll back on block
|
||||||
scan_error = _security_scan_skill(existing["path"])
|
scan_error = _security_scan_skill(existing["path"])
|
||||||
if scan_error:
|
if scan_error:
|
||||||
if original_content is not None:
|
_revert_from_backup(target, backup_path)
|
||||||
_atomic_write_text(target, original_content)
|
|
||||||
else:
|
|
||||||
target.unlink(missing_ok=True)
|
|
||||||
return {"success": False, "error": scan_error}
|
return {"success": False, "error": scan_error}
|
||||||
|
|
||||||
|
# Success — remove the backup we just created, prune any older ones
|
||||||
|
if backup_path:
|
||||||
|
backup_path.unlink(missing_ok=True)
|
||||||
|
_cleanup_old_backups(target)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"success": True,
|
"success": True,
|
||||||
"message": f"File '{file_path}' written to skill '{name}'.",
|
"message": f"File '{file_path}' written to skill '{name}'.",
|
||||||
|
|||||||
312
tools/tool_validator.py
Normal file
312
tools/tool_validator.py
Normal file
@@ -0,0 +1,312 @@
|
|||||||
|
"""
|
||||||
|
Poka-Yoke: Tool Hallucination Detection — #922.
|
||||||
|
|
||||||
|
Validation firewall between LLM tool-call output and actual execution.
|
||||||
|
|
||||||
|
Detects and blocks:
|
||||||
|
1. Unknown tool names (hallucinated tools)
|
||||||
|
2. Malformed parameters (wrong types)
|
||||||
|
3. Missing required arguments
|
||||||
|
4. Extra unknown parameters
|
||||||
|
|
||||||
|
Poka-Yoke Type: Detection (catches errors at the boundary before harm)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, List, Optional, Set
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationSeverity(Enum):
    """Severity of validation failure."""
    BLOCK = "block"  # Must block execution
    WARN = "warn"  # Warning, may proceed
    INFO = "info"  # Informational
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ValidationIssue:
    """A validation issue found."""
    severity: ValidationSeverity  # BLOCK / WARN / INFO
    code: str  # machine-readable code, e.g. "UNKNOWN_TOOL", "WRONG_TYPE"
    message: str  # human-readable description
    tool_name: str  # tool the issue was found on
    parameter: Optional[str] = None  # offending parameter, when applicable
    expected: Optional[str] = None  # expected type (for WRONG_TYPE issues)
    actual: Optional[Any] = None  # observed type name (for WRONG_TYPE issues)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ValidationResult:
    """Result of tool call validation."""
    valid: bool  # False when at least one BLOCK-severity issue was found
    tool_name: str  # the tool that was validated
    issues: List[ValidationIssue] = field(default_factory=list)  # all issues found
    corrected_args: Optional[Dict[str, Any]] = None  # auto-corrected args, if any

    @property
    def blocking_issues(self) -> List[ValidationIssue]:
        """Issues severe enough that execution must be blocked."""
        return [i for i in self.issues if i.severity == ValidationSeverity.BLOCK]

    @property
    def warnings(self) -> List[ValidationIssue]:
        """Non-blocking issues the caller may surface to the user/model."""
        return [i for i in self.issues if i.severity == ValidationSeverity.WARN]
|
||||||
|
|
||||||
|
|
||||||
|
class ToolHallucinationDetector:
|
||||||
|
"""
|
||||||
|
Poka-yoke detector for tool hallucinations.
|
||||||
|
|
||||||
|
Validates tool calls against registered schemas before execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
    def __init__(self, tool_registry: Optional[Dict] = None):
        """
        Initialize detector.

        Args:
            tool_registry: Dict of tool_name -> tool_schema
        """
        self.registry = tool_registry or {}
        # Rejected calls are recorded here for later inspection/metrics
        # (populated via _log_rejection).
        self._rejection_log: List[Dict] = []
|
||||||
|
    def register_tool(self, name: str, schema: Dict):
        """Register a tool with its JSON Schema (overwrites any existing entry)."""
        self.registry[name] = schema
|
||||||
|
|
||||||
|
    def register_tools(self, tools: Dict[str, Dict]):
        """Register multiple tools at once (mapping of name -> JSON Schema)."""
        self.registry.update(tools)
|
||||||
|
|
||||||
|
    def validate_tool_call(
        self,
        tool_name: str,
        arguments: Dict[str, Any],
        model: str = "unknown",
    ) -> ValidationResult:
        """
        Validate a tool call against the registry.

        Args:
            tool_name: Name of the tool being called
            arguments: Arguments passed to the tool
            model: Model that generated the call (for logging)

        Returns:
            ValidationResult with validation status
        """
        issues = []

        # 1. Check if tool exists — an unknown tool short-circuits: no further
        # checks are possible without a schema, so we log and return early.
        if tool_name not in self.registry:
            issue = ValidationIssue(
                severity=ValidationSeverity.BLOCK,
                code="UNKNOWN_TOOL",
                message=f"Tool '{tool_name}' does not exist. Available: {', '.join(sorted(self.registry.keys())[:10])}...",
                tool_name=tool_name,
            )
            issues.append(issue)
            self._log_rejection(tool_name, arguments, model, "UNKNOWN_TOOL")
            return ValidationResult(valid=False, tool_name=tool_name, issues=issues)

        schema = self.registry[tool_name]
        params_schema = schema.get("parameters", {}).get("properties", {})
        required = set(schema.get("parameters", {}).get("required", []))

        # 2. Check for missing required parameters (BLOCK severity)
        for param in required:
            if param not in arguments:
                issue = ValidationIssue(
                    severity=ValidationSeverity.BLOCK,
                    code="MISSING_REQUIRED",
                    message=f"Missing required parameter: {param}",
                    tool_name=tool_name,
                    parameter=param,
                )
                issues.append(issue)

        # 3. Check parameter types against the schema
        for param_name, param_value in arguments.items():
            if param_name not in params_schema:
                # Unknown parameter — WARN only, so validation does not
                # block on schema drift / optional extras.
                issue = ValidationIssue(
                    severity=ValidationSeverity.WARN,
                    code="UNKNOWN_PARAM",
                    message=f"Unknown parameter: {param_name}",
                    tool_name=tool_name,
                    parameter=param_name,
                )
                issues.append(issue)
                continue

            param_schema = params_schema[param_name]
            expected_type = param_schema.get("type")

            # A declared type that the value does not satisfy is blocking.
            if expected_type and not self._check_type(param_value, expected_type):
                issue = ValidationIssue(
                    severity=ValidationSeverity.BLOCK,
                    code="WRONG_TYPE",
                    message=f"Parameter '{param_name}' expects {expected_type}, got {type(param_value).__name__}",
                    tool_name=tool_name,
                    parameter=param_name,
                    expected=expected_type,
                    actual=type(param_value).__name__,
                )
                issues.append(issue)

        # 4. Check for common hallucination patterns (placeholders, etc.)
        hallucination_issues = self._detect_hallucination_patterns(tool_name, arguments)
        issues.extend(hallucination_issues)

        # Determine validity: only BLOCK-severity issues invalidate the call.
        has_blocking = any(i.severity == ValidationSeverity.BLOCK for i in issues)

        if has_blocking:
            self._log_rejection(tool_name, arguments, model,
                                "; ".join(i.code for i in issues if i.severity == ValidationSeverity.BLOCK))

        return ValidationResult(
            valid=not has_blocking,
            tool_name=tool_name,
            issues=issues,
        )
|
||||||
|
|
||||||
|
def _check_type(self, value: Any, expected_type: str) -> bool:
|
||||||
|
"""Check if value matches expected JSON Schema type."""
|
||||||
|
type_map = {
|
||||||
|
"string": str,
|
||||||
|
"number": (int, float),
|
||||||
|
"integer": int,
|
||||||
|
"boolean": bool,
|
||||||
|
"array": list,
|
||||||
|
"object": dict,
|
||||||
|
}
|
||||||
|
|
||||||
|
expected = type_map.get(expected_type)
|
||||||
|
if expected is None:
|
||||||
|
return True # Unknown type, assume OK
|
||||||
|
|
||||||
|
return isinstance(value, expected)
|
||||||
|
|
||||||
|
def _detect_hallucination_patterns(self, tool_name: str, arguments: Dict) -> List[ValidationIssue]:
    """Scan argument values for telltale signs of hallucinated input.

    Produces WARN-level issues only; nothing here blocks a call.
    """
    found: List[ValidationIssue] = []

    # Regexes matching typical placeholder text a model may emit
    # instead of a real value.
    suspicious = (
        r"^<.*>$",            # <placeholder>
        r"^\[.*\]$",          # [placeholder]
        r"^TODO$|^FIXME$",    # bare TODO/FIXME markers
        r"^example\.com$",    # canonical example domain
        r"^127\.0\.0\.1$",    # loopback address
    )

    # Pass 1: placeholder-looking string values.
    for name, value in arguments.items():
        if not isinstance(value, str):
            continue
        for rx in suspicious:
            if re.match(rx, value, re.IGNORECASE):
                found.append(ValidationIssue(
                    severity=ValidationSeverity.WARN,
                    code="PLACEHOLDER_VALUE",
                    message=f"Parameter '{name}' contains placeholder: {value}",
                    tool_name=tool_name,
                    parameter=name,
                ))

    # Pass 2: extremely long strings, which may be hallucinated content.
    for name, value in arguments.items():
        if isinstance(value, str) and len(value) > 10000:
            found.append(ValidationIssue(
                severity=ValidationSeverity.WARN,
                code="SUSPICIOUS_LENGTH",
                message=f"Parameter '{name}' is unusually long ({len(value)} chars)",
                tool_name=tool_name,
                parameter=name,
            ))

    return found
|
||||||
|
|
||||||
|
def _log_rejection(self, tool_name: str, arguments: Dict, model: str, reason: str):
    """Record a rejected tool call in the in-memory log and emit a warning.

    Argument values are truncated so one oversized value cannot bloat
    the log; the log itself is bounded (see below).
    """
    import time

    self._rejection_log.append({
        "timestamp": time.time(),
        "tool_name": tool_name,
        # Keep only a 100-char preview of each argument value.
        "arguments": {key: str(val)[:100] for key, val in arguments.items()},
        "model": model,
        "reason": reason,
    })

    # Bound the log: once it grows past 1000 entries, keep the newest 500.
    if len(self._rejection_log) > 1000:
        self._rejection_log = self._rejection_log[-500:]

    logger.warning(
        "Tool hallucination blocked: tool=%s, model=%s, reason=%s",
        tool_name, model, reason
    )
|
||||||
|
|
||||||
|
def get_rejection_stats(self) -> Dict:
    """Summarize the rejection log, counting entries by reason and by tool."""
    log = self._rejection_log
    if not log:
        return {"total": 0, "by_reason": {}, "by_tool": {}}

    reason_counts: Dict = {}
    tool_counts: Dict = {}
    for record in log:
        reason_counts[record["reason"]] = reason_counts.get(record["reason"], 0) + 1
        tool_counts[record["tool_name"]] = tool_counts.get(record["tool_name"], 0) + 1

    return {
        "total": len(log),
        "by_reason": reason_counts,
        "by_tool": tool_counts,
    }
|
||||||
|
|
||||||
|
def format_validation_report(self, result: ValidationResult) -> str:
    """Render a validation result as a short human-readable report.

    Valid results collapse to a single line; blocked results list
    blocking issues first, then warnings.
    """
    if result.valid:
        return f"✅ {result.tool_name}: valid"

    report = [f"❌ {result.tool_name}: BLOCKED"]
    report.extend(f"  [{issue.code}] {issue.message}" for issue in result.blocking_issues)
    report.extend(f"  ⚠️ [{issue.code}] {issue.message}" for issue in result.warnings)
    return "\n".join(report)
|
||||||
|
|
||||||
|
|
||||||
|
def create_rejection_response(result: ValidationResult) -> Dict:
    """
    Build a tool-role message describing why a tool call was rejected.

    Returning this as the tool result lets the agent observe the
    rejection and self-correct on its next attempt.
    """
    bullet_list = "\n".join(
        f"- [{i.code}] {i.message}" for i in result.blocking_issues
    )

    content = (
        f"Tool call rejected: {result.tool_name}\n"
        "\n"
        "Issues found:\n"
        f"{bullet_list}\n"
        "\n"
        "Please check the tool name and parameters, then try again with valid arguments."
    )
    return {"role": "tool", "content": content}
|
||||||
Reference in New Issue
Block a user