Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0822837ec3 |
@@ -1,302 +0,0 @@
|
||||
"""Self-Modifying Prompt Engine — agent learns from its own failures.
|
||||
|
||||
Analyzes session transcripts, identifies failure patterns, and generates
|
||||
prompt patches to prevent future failures.
|
||||
|
||||
The loop: fail → analyze → rewrite → retry → verify improvement.
|
||||
|
||||
Usage:
|
||||
from agent.self_modify import PromptLearner
|
||||
learner = PromptLearner()
|
||||
    patterns = learner.analyze_session(session_data)
    patches = learner.generate_patches(patterns)
|
||||
learner.apply_patches(patches)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
PATCHES_DIR = HERMES_HOME / "prompt_patches"
|
||||
ROLLBACK_DIR = HERMES_HOME / "prompt_rollback"
|
||||
|
||||
|
||||
@dataclass
class FailurePattern:
    """A detected failure pattern in session transcripts."""

    # Category key; one of the FAILURE_SIGNALS keys.
    # NOTE(review): the inline comment below lists "error_hallucination",
    # but the detection table's key is "hallucination" (plus "tool_failure").
    pattern_type: str  # retry_loop, timeout, error_hallucination, context_loss
    # Human-readable summary, copied from FAILURE_SIGNALS[...]["description"].
    description: str
    # Number of messages in the session that matched this pattern.
    frequency: int
    # Up to three 200-char message excerpts kept as evidence (analyze_session).
    example_messages: List[str] = field(default_factory=list)
    # Rule text from PROMPT_FIXES proposed to prevent recurrence.
    suggested_fix: str = ""
||||
@dataclass
class PromptPatch:
    """A modification to the system prompt based on failure analysis."""

    # Identifier of the form "<failure_type>-<unix timestamp>".
    id: str
    # FailurePattern.pattern_type this patch addresses.
    failure_type: str
    # Rule being replaced; generate_patches always sets a "(missing — ...)"
    # placeholder since no existing rule lookup is performed in this module.
    original_rule: str
    # Rule text appended to the system prompt by apply_patches.
    new_rule: str
    # 0.5–0.9, derived from pattern frequency in generate_patches.
    confidence: float
    # Unix timestamp set by apply_patches when the rule is written.
    applied_at: Optional[float] = None
    # Rollback marker; not set anywhere in this module as written.
    reverted: bool = False
||||
# Failure detection patterns.
# Maps failure type -> {"patterns": [regexes], "description": str}.
# Every regex carries an inline (?i) flag and is applied with re.search
# against assistant/tool message text in PromptLearner.analyze_session.
FAILURE_SIGNALS = {
    "retry_loop": {
        "patterns": [
            r"(?i)retry(?:ing)?\s*(?:attempt|again)",
            r"(?i)failed.*retrying",
            r"(?i)error.*again",
            # e.g. "attempt 2 of 5" / "attempt 2/5"
            r"(?i)attempt\s+\d+\s*(?:of|/)\s*\d+",
        ],
        "description": "Agent stuck in retry loop",
    },
    "timeout": {
        "patterns": [
            r"(?i)timed?\s*out",
            r"(?i)deadline\s+exceeded",
            r"(?i)took\s+(?:too\s+)?long",
        ],
        "description": "Operation timed out",
    },
    "hallucination": {
        "patterns": [
            r"(?i)i\s+(?:don't|do\s+not)\s+(?:have|see|find)\s+(?:any|that|this)\s+(?:information|data|file)",
            r"(?i)the\s+file\s+doesn't\s+exist",
            r"(?i)i\s+(?:made|invented|fabricated)\s+(?:that\s+up|this)",
        ],
        "description": "Agent hallucinated or fabricated information",
    },
    "context_loss": {
        "patterns": [
            r"(?i)i\s+(?:don't|do\s+not)\s+(?:remember|recall|know)\s+(?:what|where|when|how)",
            r"(?i)could\s+you\s+remind\s+me",
            r"(?i)what\s+were\s+we\s+(?:doing|working|talking)\s+(?:on|about)",
        ],
        "description": "Agent lost context from earlier in conversation",
    },
    "tool_failure": {
        "patterns": [
            r"(?i)tool\s+(?:call|execution)\s+failed",
            r"(?i)command\s+not\s+found",
            r"(?i)permission\s+denied",
            r"(?i)no\s+such\s+file",
        ],
        "description": "Tool execution failed",
    },
}
|
||||
# Prompt improvement templates.
# Maps failure type (same keys as FAILURE_SIGNALS) to the rule text that
# generate_patches wraps into a PromptPatch.new_rule and apply_patches
# appends to the system prompt verbatim.
PROMPT_FIXES = {
    "retry_loop": (
        "If an operation fails more than twice, stop retrying. "
        "Report the failure and ask the user for guidance. "
        "Do not enter retry loops — they waste tokens."
    ),
    "timeout": (
        "For operations that may take long, set a timeout and report "
        "progress. If an operation takes more than 30 seconds, report "
        "what you've done so far and ask if you should continue."
    ),
    "hallucination": (
        "If you cannot find information, say 'I don't know' or "
        "'I couldn't find that.' Never fabricate information. "
        "If a file doesn't exist, say so — don't guess its contents."
    ),
    "context_loss": (
        "When you need context from earlier in the conversation, "
        "use session_search to find it. Don't ask the user to repeat themselves."
    ),
    "tool_failure": (
        "If a tool fails, check the error message and try a different approach. "
        "Don't retry the exact same command — diagnose first."
    ),
}
|
||||
class PromptLearner:
    """Analyze session transcripts and generate prompt improvements.

    The learning loop: detect failure patterns in a transcript
    (analyze_session), turn frequent patterns into patches
    (generate_patches), append new rules to the system prompt
    (apply_patches), and keep timestamped backups so any change can be
    undone (rollback_last). learn_from_session runs the whole cycle.
    """

    def __init__(self):
        # Create the patch-log and rollback directories up front so later
        # writes cannot fail on a missing directory.
        PATCHES_DIR.mkdir(parents=True, exist_ok=True)
        ROLLBACK_DIR.mkdir(parents=True, exist_ok=True)

    def analyze_session(self, session_data: dict) -> List[FailurePattern]:
        """Analyze a session for failure patterns.

        Args:
            session_data: Session dict with 'messages' list.

        Returns:
            List of detected failure patterns.
        """
        messages = session_data.get("messages", [])
        patterns_found: Dict[str, FailurePattern] = {}

        for msg in messages:
            content = str(msg.get("content", ""))
            role = msg.get("role", "")

            # Only analyze assistant messages and tool results; user text
            # would trigger false positives (e.g. the user typing "retry").
            if role not in ("assistant", "tool"):
                continue

            for failure_type, config in FAILURE_SIGNALS.items():
                for pattern in config["patterns"]:
                    if re.search(pattern, content):
                        if failure_type not in patterns_found:
                            patterns_found[failure_type] = FailurePattern(
                                pattern_type=failure_type,
                                description=config["description"],
                                frequency=0,
                                suggested_fix=PROMPT_FIXES.get(failure_type, ""),
                            )
                        patterns_found[failure_type].frequency += 1
                        # Keep up to three short excerpts as evidence.
                        if len(patterns_found[failure_type].example_messages) < 3:
                            patterns_found[failure_type].example_messages.append(
                                content[:200]
                            )
                        break  # One match per message per type is enough

        return list(patterns_found.values())

    def generate_patches(self, patterns: List[FailurePattern],
                         min_confidence: float = 0.7) -> List[PromptPatch]:
        """Generate prompt patches from failure patterns.

        Args:
            patterns: Detected failure patterns.
            min_confidence: Minimum confidence to generate a patch.

        Returns:
            List of prompt patches (patterns below the confidence cutoff
            or without a known fix are skipped).
        """
        patches = []
        for pattern in patterns:
            # Confidence grows with how often the pattern was observed.
            if pattern.frequency >= 3:
                confidence = 0.9
            elif pattern.frequency >= 2:
                confidence = 0.75
            else:
                confidence = 0.5

            if confidence < min_confidence:
                continue

            if not pattern.suggested_fix:
                continue

            patch = PromptPatch(
                id=f"{pattern.pattern_type}-{int(time.time())}",
                failure_type=pattern.pattern_type,
                original_rule="(missing — no existing rule for this pattern)",
                new_rule=pattern.suggested_fix,
                confidence=confidence,
            )
            patches.append(patch)

        return patches

    def apply_patches(self, patches: List[PromptPatch],
                      prompt_path: Optional[str] = None) -> int:
        """Apply patches to the system prompt.

        Args:
            patches: Patches to apply.
            prompt_path: Path to prompt file (default: ~/.hermes/system_prompt.md)

        Returns:
            Number of patches applied.
        """
        if prompt_path is None:
            prompt_path = str(HERMES_HOME / "system_prompt.md")

        prompt_file = Path(prompt_path)

        # Backup current prompt so rollback_last can undo this change.
        if prompt_file.exists():
            backup = ROLLBACK_DIR / f"{prompt_file.name}.{int(time.time())}.bak"
            backup.write_text(prompt_file.read_text(encoding="utf-8"),
                              encoding="utf-8")

        current = prompt_file.read_text(encoding="utf-8") if prompt_file.exists() else ""

        # Append each rule that is not already present in the prompt.
        applied = 0
        additions = []
        for patch in patches:
            if patch.new_rule not in current:
                additions.append(f"\n## Auto-learned: {patch.failure_type}\n{patch.new_rule}")
                patch.applied_at = time.time()
                applied += 1

        if additions:
            # First-run safety: the prompt file's directory may not exist yet.
            prompt_file.parent.mkdir(parents=True, exist_ok=True)
            prompt_file.write_text(current + "\n".join(additions), encoding="utf-8")

        # Log the batch for auditing. Skip when there is nothing to record —
        # previously an empty JSON file was written on every call.
        if patches:
            patches_file = PATCHES_DIR / f"patches-{int(time.time())}.json"
            with open(patches_file, "w", encoding="utf-8") as f:
                json.dump([p.__dict__ for p in patches], f, indent=2, default=str)

        logger.info("Applied %d prompt patches", applied)
        return applied

    def rollback_last(self, prompt_path: Optional[str] = None) -> bool:
        """Rollback to the most recent backup.

        Args:
            prompt_path: Path to prompt file.

        Returns:
            True if rollback succeeded.
        """
        if prompt_path is None:
            prompt_path = str(HERMES_HOME / "system_prompt.md")

        # Sort by modification time: name order is lexicographic, which
        # misorders epoch timestamps once their digit counts differ.
        backups = sorted(ROLLBACK_DIR.glob("*.bak"),
                         key=lambda p: p.stat().st_mtime, reverse=True)
        if not backups:
            logger.warning("No backups to rollback to")
            return False

        latest = backups[0]
        Path(prompt_path).write_text(latest.read_text(encoding="utf-8"),
                                     encoding="utf-8")
        logger.info("Rolled back to %s", latest.name)
        return True

    def learn_from_session(self, session_data: dict) -> Dict[str, Any]:
        """Full learning cycle: analyze → patch → apply.

        Args:
            session_data: Session dict.

        Returns:
            Summary of what was learned and applied.
        """
        patterns = self.analyze_session(session_data)
        patches = self.generate_patches(patterns)
        applied = self.apply_patches(patches)

        return {
            "patterns_detected": len(patterns),
            "patches_generated": len(patches),
            "patches_applied": applied,
            "patterns": [
                {"type": p.pattern_type, "frequency": p.frequency, "description": p.description}
                for p in patterns
            ],
        }
@@ -1,265 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Hermes MCP Server — expose hermes-agent tools to fleet peers.
|
||||
|
||||
Runs as a standalone MCP server that other agents can connect to
|
||||
and invoke hermes tools remotely.
|
||||
|
||||
Safe tools exposed:
|
||||
- terminal (safe commands only)
|
||||
- file_read, file_search
|
||||
- web_search, web_extract
|
||||
- session_search
|
||||
|
||||
NOT exposed (internal tools):
|
||||
- approval, delegate, memory, config
|
||||
|
||||
Usage:
|
||||
python -m tools.mcp_server --port 8081
|
||||
hermes mcp-server --port 8081
|
||||
python scripts/mcp_server.py --port 8081 --auth-key SECRET
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools safe to expose to other agents.
# Each entry is an MCP tool descriptor (name, description, JSON-schema
# parameters) returned verbatim by handle_tools_list; only these names
# are accepted by handle_tools_call.
SAFE_TOOLS = {
    "terminal": {
        "name": "terminal",
        "description": "Execute safe shell commands. Dangerous commands are blocked.",
        "parameters": {
            "type": "object",
            "properties": {
                "command": {"type": "string", "description": "Shell command to execute"},
            },
            "required": ["command"],
        },
    },
    "file_read": {
        "name": "file_read",
        "description": "Read the contents of a file.",
        "parameters": {
            "type": "object",
            "properties": {
                "path": {"type": "string", "description": "File path to read"},
                "offset": {"type": "integer", "description": "Start line", "default": 1},
                "limit": {"type": "integer", "description": "Max lines", "default": 200},
            },
            "required": ["path"],
        },
    },
    "file_search": {
        "name": "file_search",
        "description": "Search file contents using regex.",
        "parameters": {
            "type": "object",
            "properties": {
                "pattern": {"type": "string", "description": "Regex pattern"},
                "path": {"type": "string", "description": "Directory to search", "default": "."},
            },
            "required": ["pattern"],
        },
    },
    "web_search": {
        "name": "web_search",
        "description": "Search the web for information.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"},
            },
            "required": ["query"],
        },
    },
    "session_search": {
        "name": "session_search",
        "description": "Search past conversation sessions.",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"},
                "limit": {"type": "integer", "description": "Max results", "default": 3},
            },
            "required": ["query"],
        },
    },
}
|
||||
# Tools explicitly blocked.
# Internal/privileged tools that must never be callable by fleet peers;
# handle_tools_call rejects these by name before the SAFE_TOOLS lookup.
BLOCKED_TOOLS = {
    "approval", "delegate", "memory", "config", "skill_install",
    "mcp_tool", "cronjob", "tts", "send_message",
}
||||
|
||||
class MCPServer:
    """Simple MCP-compatible server for exposing hermes tools.

    Serves tools/list and tools/call endpoints (plus an unauthenticated
    /health probe) over aiohttp. Only tools in SAFE_TOOLS are callable;
    BLOCKED_TOOLS are rejected by name. An optional bearer token guards
    the tool routes.
    """

    def __init__(self, host: str = "127.0.0.1", port: int = 8081,
                 auth_key: Optional[str] = None):
        self._host = host
        self._port = port
        # Explicit key wins; otherwise fall back to the MCP_AUTH_KEY env var.
        self._auth_key = auth_key or os.getenv("MCP_AUTH_KEY", "")

    def _authorized(self, request) -> bool:
        """Check the request's bearer token; always True when no key is set.

        Uses hmac.compare_digest so the comparison is constant-time — a
        plain != short-circuits and can leak key prefixes via timing.
        """
        if not self._auth_key:
            return True
        import hmac
        auth = request.headers.get("Authorization", "")
        return hmac.compare_digest(auth, f"Bearer {self._auth_key}")

    async def handle_tools_list(self, request: dict) -> dict:
        """Return available tools."""
        tools = list(SAFE_TOOLS.values())
        return {"tools": tools}

    async def handle_tools_call(self, request: dict) -> dict:
        """Execute a tool call.

        Returns an MCP-style content payload on success, or an
        {"error": ...} dict for blocked/unknown tools and failures.
        """
        tool_name = request.get("name", "")
        arguments = request.get("arguments", {})

        if tool_name in BLOCKED_TOOLS:
            return {"error": f"Tool '{tool_name}' is not exposed via MCP"}
        if tool_name not in SAFE_TOOLS:
            return {"error": f"Unknown tool: {tool_name}"}

        try:
            result = await self._execute_tool(tool_name, arguments)
            return {"content": [{"type": "text", "text": str(result)}]}
        except Exception as e:
            # Boundary handler: report the failure to the remote caller
            # instead of letting it take down the request.
            return {"error": str(e)}

    async def _execute_tool(self, tool_name: str, arguments: dict) -> str:
        """Execute a tool and return its result as a string."""
        if tool_name == "terminal":
            import subprocess
            cmd = arguments.get("command", "")
            # Reuse the main agent's dangerous-command detection. Note:
            # shell=True below means this check is the only guard against
            # injection from remote peers.
            from tools.approval import detect_dangerous_command
            is_dangerous, _, desc = detect_dangerous_command(cmd)
            if is_dangerous:
                return f"BLOCKED: Dangerous command detected ({desc}). This tool only executes safe commands."
            result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=30)
            return result.stdout or result.stderr or "(no output)"

        elif tool_name == "file_read":
            path = arguments.get("path", "")
            offset = arguments.get("offset", 1)
            limit = arguments.get("limit", 200)
            # Clamp the start index: offset=0 previously produced
            # lines[-1:...], silently returning only the last line.
            start = max(int(offset) - 1, 0)
            with open(path, encoding="utf-8", errors="replace") as f:
                lines = f.readlines()
            return "".join(lines[start:start + int(limit)])

        elif tool_name == "file_search":
            import re
            pattern = arguments.get("pattern", "")
            path = arguments.get("path", ".")
            # Compile once, outside the per-file/per-line loops.
            rx = re.compile(pattern, re.IGNORECASE)
            results = []
            for p in Path(path).rglob("*.py"):
                try:
                    content = p.read_text()
                    for i, line in enumerate(content.split("\n"), 1):
                        if rx.search(line):
                            results.append(f"{p}:{i}: {line.strip()}")
                            if len(results) >= 20:
                                break
                except Exception:
                    # Unreadable/undecodable file — skip and keep searching.
                    continue
                if len(results) >= 20:
                    break
            return "\n".join(results) or "No matches found"

        elif tool_name == "web_search":
            try:
                from tools.web_tools import web_search
                return web_search(arguments.get("query", ""))
            except ImportError:
                return "Web search not available"

        elif tool_name == "session_search":
            try:
                from tools.session_search_tool import session_search
                return session_search(
                    query=arguments.get("query", ""),
                    limit=arguments.get("limit", 3),
                )
            except ImportError:
                return "Session search not available"

        return f"Tool {tool_name} not implemented"

    async def start_http(self):
        """Start the HTTP server and serve until the task is cancelled."""
        try:
            from aiohttp import web
        except ImportError:
            logger.error("aiohttp required: pip install aiohttp")
            return

        app = web.Application()

        async def handle_tools_list_route(request):
            if not self._authorized(request):
                return web.json_response({"error": "Unauthorized"}, status=401)
            result = await self.handle_tools_list({})
            return web.json_response(result)

        async def handle_tools_call_route(request):
            if not self._authorized(request):
                return web.json_response({"error": "Unauthorized"}, status=401)
            body = await request.json()
            result = await self.handle_tools_call(body)
            return web.json_response(result)

        async def handle_health(request):
            # Unauthenticated liveness probe.
            return web.json_response({"status": "ok", "tools": len(SAFE_TOOLS)})

        app.router.add_get("/mcp/tools", handle_tools_list_route)
        app.router.add_post("/mcp/tools/call", handle_tools_call_route)
        app.router.add_get("/health", handle_health)

        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, self._host, self._port)
        await site.start()
        logger.info("MCP server on http://%s:%s", self._host, self._port)
        logger.info("Tools: %s", ", ".join(SAFE_TOOLS.keys()))
        if self._auth_key:
            logger.info("Auth: Bearer token required")
        else:
            logger.warning("Auth: No MCP_AUTH_KEY set — server is open")

        try:
            # Park forever; cancellation (e.g. Ctrl-C via asyncio.run)
            # arrives here as CancelledError.
            await asyncio.Event().wait()
        except asyncio.CancelledError:
            pass
        finally:
            await runner.cleanup()
|
||||
def main():
    """CLI entry point: parse arguments, configure logging, run the server."""
    arg_parser = argparse.ArgumentParser(description="Hermes MCP Server")
    arg_parser.add_argument("--host", default="127.0.0.1")
    arg_parser.add_argument("--port", type=int, default=8081)
    arg_parser.add_argument("--auth-key", default=None, help="Bearer token for auth")
    opts = arg_parser.parse_args()

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(name)s] %(levelname)s: %(message)s",
    )

    mcp = MCPServer(host=opts.host, port=opts.port, auth_key=opts.auth_key)
    print(f"Starting MCP server on http://{opts.host}:{opts.port}")
    print(f"Exposed tools: {', '.join(SAFE_TOOLS.keys())}")
    asyncio.run(mcp.start_http())
||||
@@ -201,31 +201,8 @@ def _get_command_timeout() -> int:
|
||||
|
||||
|
||||
def _get_vision_model() -> Optional[str]:
|
||||
"""Model for browser_vision (screenshot analysis — multimodal).
|
||||
|
||||
Priority:
|
||||
1. AUXILIARY_VISION_MODEL env var (explicit override)
|
||||
2. Gemma 4 (native multimodal, no model switching)
|
||||
3. Ollama local vision models
|
||||
4. None (fallback to text-only snapshot)
|
||||
"""
|
||||
# Explicit override always wins
|
||||
explicit = os.getenv("AUXILIARY_VISION_MODEL", "").strip()
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
# Prefer Gemma 4 (native multimodal — no separate vision model needed)
|
||||
gemma = os.getenv("GEMMA_VISION_MODEL", "").strip()
|
||||
if gemma:
|
||||
return gemma
|
||||
|
||||
# Check for Ollama vision models
|
||||
ollama_vision = os.getenv("OLLAMA_VISION_MODEL", "").strip()
|
||||
if ollama_vision:
|
||||
return ollama_vision
|
||||
|
||||
# Default: None (text-only fallback)
|
||||
return None
|
||||
"""Model for browser_vision (screenshot analysis — multimodal)."""
|
||||
return os.getenv("AUXILIARY_VISION_MODEL", "").strip() or None
|
||||
|
||||
|
||||
def _get_extraction_model() -> Optional[str]:
|
||||
|
||||
313
tools/risk_scorer.py
Normal file
313
tools/risk_scorer.py
Normal file
@@ -0,0 +1,313 @@
|
||||
"""Context-Aware Risk Scoring — ML-lite tier detection enhancement.
|
||||
|
||||
Enhances the existing approval.py dangerous-command detection with
|
||||
context-aware risk scoring. Instead of pure pattern matching, considers:
|
||||
|
||||
1. Path context: rm /tmp/x is safer than rm /etc/passwd
|
||||
2. Command context: chmod 777 on project dir vs system dir
|
||||
3. Recency: repeated dangerous commands increase risk
|
||||
4. Scope: commands affecting more files = higher risk
|
||||
|
||||
Usage:
|
||||
from tools.risk_scorer import score_action, RiskResult
|
||||
result = score_action("rm -rf /tmp/build")
|
||||
# result.tier = MEDIUM (not HIGH, because /tmp is safe)
|
||||
# result.confidence = 0.7
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
# Risk tiers (aligned with approval_tiers.py)
class RiskTier(IntEnum):
    """Ordered risk levels; IntEnum so tiers compare/sort numerically."""

    SAFE = 0      # default for commands not in COMMAND_BASE_TIERS
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4  # assigned when a critical path is involved
|
||||
@dataclass
class RiskResult:
    """Result of risk scoring."""

    # Final tier after path/context adjustments (see score_action).
    tier: RiskTier
    confidence: float  # 0.0 to 1.0
    # Human-readable explanations of what drove the score.
    reasons: List[str] = field(default_factory=list)
    # Raw scoring inputs (path_risk, modifier, scope, recency, paths).
    context_factors: Dict[str, Any] = field(default_factory=dict)
||||
|
||||
# --- Path risk assessment ---
|
||||
|
||||
SAFE_PATHS = {
|
||||
"/tmp", "/var/tmp", "/dev/shm",
|
||||
"~/.hermes/sessions", "~/.hermes/cache", "~/.hermes/logs",
|
||||
"/tmp/", "/var/tmp/",
|
||||
}
|
||||
|
||||
HIGH_RISK_PATHS = {
|
||||
"/etc", "/boot", "/usr/lib", "/usr/bin",
|
||||
"~/.ssh", "~/.gnupg",
|
||||
"/var/lib", "/opt",
|
||||
}
|
||||
|
||||
CRITICAL_PATHS = {
|
||||
"/", "/etc/passwd", "/etc/shadow", "/etc/sudoers",
|
||||
"~/.ssh/id_rsa", "~/.ssh/authorized_keys",
|
||||
"/boot/vmlinuz", "/dev/sda", "/dev/nvme",
|
||||
}
|
||||
|
||||
|
||||
def _extract_paths(command: str) -> List[str]:
|
||||
"""Extract file paths from a command."""
|
||||
paths = []
|
||||
# Match common path patterns
|
||||
for match in re.finditer(r'[/~][\w/.~-]+', command):
|
||||
paths.append(match.group())
|
||||
# Also match $HOME, $HERMES_HOME expansions
|
||||
for match in re.finditer(r'\$(?:HOME|HERMES_HOME|PWD)[/\w]*', command):
|
||||
paths.append(match.group())
|
||||
return paths
|
||||
|
||||
|
||||
def _classify_path(path: str) -> str:
|
||||
"""Classify a path as safe, high-risk, or critical."""
|
||||
path_lower = path.lower().replace("\\", "/")
|
||||
|
||||
for critical in CRITICAL_PATHS:
|
||||
if path_lower.startswith(critical.lower()):
|
||||
return "critical"
|
||||
|
||||
for high in HIGH_RISK_PATHS:
|
||||
if path_lower.startswith(high.lower()):
|
||||
return "high"
|
||||
|
||||
for safe in SAFE_PATHS:
|
||||
if path_lower.startswith(safe.lower()):
|
||||
return "safe"
|
||||
|
||||
# Unknown paths default to medium
|
||||
return "unknown"
|
||||
|
||||
|
||||
# --- Command risk modifiers ---
|
||||
|
||||
RISK_MODIFIERS = {
|
||||
# Flags that increase risk
|
||||
"-rf": 1.5,
|
||||
"-r": 1.2,
|
||||
"--force": 1.5,
|
||||
"--recursive": 1.2,
|
||||
"--no-preserve-root": 3.0,
|
||||
"-f": 1.3,
|
||||
"--hard": 1.5,
|
||||
"--force-push": 2.0,
|
||||
"-D": 1.4,
|
||||
# Flags that decrease risk
|
||||
"--dry-run": 0.1,
|
||||
"-n": 0.3,
|
||||
"--no-act": 0.1,
|
||||
"--interactive": 0.7,
|
||||
"-i": 0.7,
|
||||
}
|
||||
|
||||
|
||||
def _get_command_risk_modifier(command: str) -> float:
|
||||
"""Get risk modifier based on command flags."""
|
||||
modifier = 1.0
|
||||
for flag, mod in RISK_MODIFIERS.items():
|
||||
if flag in command:
|
||||
modifier *= mod
|
||||
return modifier
|
||||
|
||||
|
||||
# --- Scope assessment ---

def _assess_scope(command: str) -> float:
    """Estimate how widely a command reaches (files/targets) as a multiplier.

    Heuristics: shell wildcards, recursive flags, find→exec/xargs
    pipelines, and multiple explicit path arguments each widen the scope.
    The combined factor is capped at 5x.
    """
    factor = 1.0

    # Shell globbing can touch arbitrarily many files.
    if any(ch in command for ch in "*?"):
        factor *= 2.0

    # Recursive flags (-r / -rf) walk whole directory trees.
    if re.search(r'-r[f]?\b', command):
        factor *= 1.5

    # find fanned out through -exec or xargs hits every match.
    if "find" in command and ("exec" in command or "xargs" in command):
        factor *= 2.0

    # Several explicit path arguments mean several targets.
    if len(_extract_paths(command)) > 2:
        factor *= 1.3

    return min(factor, 5.0)  # Cap at 5x
||||
|
||||
# --- Recent command tracking ---

# Module-level sliding window of (timestamp, command) pairs consulted by
# _track_command to escalate risk for repeated commands.
# NOTE(review): plain module state with no locking — assumes
# single-threaded callers; confirm before sharing across threads.
_recent_commands: List[Tuple[float, str]] = []
_TRACK_WINDOW = 300  # 5 minutes
||||
|
||||
def _track_command(command: str) -> float:
    """Record *command* and return an escalation factor based on recency.

    Commands seen within the last _TRACK_WINDOW seconds escalate risk:
    +0.2 per exact repeat, +0.1 per command with the same program name.
    Capped at 3x. Mutates the module-level _recent_commands window.
    """
    global _recent_commands

    now = time.time()

    # Drop entries that have aged out of the tracking window.
    _recent_commands = [
        entry for entry in _recent_commands if now - entry[0] < _TRACK_WINDOW
    ]

    escalation = 1.0
    for _, previous in _recent_commands:
        if previous == command:
            # Exact repeat — strongest escalation.
            escalation += 0.2
        elif _commands_similar(command, previous):
            # Same program, different arguments — moderate escalation.
            escalation += 0.1

    _recent_commands.append((now, command))
    return min(escalation, 3.0)  # Cap at 3x
||||
|
||||
def _commands_similar(cmd1: str, cmd2: str) -> bool:
|
||||
"""Check if two commands are structurally similar."""
|
||||
# Extract command name
|
||||
name1 = cmd1.split()[0] if cmd1.split() else ""
|
||||
name2 = cmd2.split()[0] if cmd2.split() else ""
|
||||
return name1 == name2
|
||||
|
||||
|
||||
# --- Main scoring function ---

# Base tier mapping from command name.
# score_action looks the bare program name up here (path prefix stripped);
# commands absent from the table default to RiskTier.SAFE.
COMMAND_BASE_TIERS = {
    "rm": RiskTier.HIGH,
    "chmod": RiskTier.MEDIUM,
    "chown": RiskTier.HIGH,
    "mkfs": RiskTier.CRITICAL,
    "dd": RiskTier.HIGH,
    "kill": RiskTier.HIGH,
    "pkill": RiskTier.HIGH,
    "systemctl": RiskTier.HIGH,
    "git": RiskTier.LOW,
    "sed": RiskTier.LOW,
    "cp": RiskTier.LOW,
    "mv": RiskTier.LOW,
    "python3": RiskTier.LOW,
    "pip": RiskTier.LOW,
    "npm": RiskTier.LOW,
    "docker": RiskTier.MEDIUM,
    "ansible": RiskTier.HIGH,
}
|
||||
def score_action(action: str, context: Optional[Dict[str, Any]] = None) -> RiskResult:
    """Score an action's risk level with context awareness.

    Considers:
      - Command base risk (COMMAND_BASE_TIERS)
      - Path context (safe vs critical paths)
      - Command flags (force, recursive, dry-run)
      - Scope (wildcards, multiple targets)
      - Recency (repeated commands escalate)

    Args:
        action: Shell command string to score.
        context: Accepted for interface stability; not used by the
            current implementation.

    Returns:
        RiskResult with tier, confidence, and reasons.
    """
    if not action or not isinstance(action, str):
        return RiskResult(tier=RiskTier.SAFE, confidence=1.0, reasons=["empty input"])

    parts = action.strip().split()
    if not parts:
        return RiskResult(tier=RiskTier.SAFE, confidence=1.0, reasons=["empty command"])

    # Strip any leading directory so "/usr/bin/rm" maps to "rm".
    cmd_name = parts[0].split("/")[-1]

    # Base tier from command name; unlisted commands start SAFE.
    base_tier = COMMAND_BASE_TIERS.get(cmd_name, RiskTier.SAFE)

    # Find the riskiest referenced path, remembering WHICH path it was so
    # the reason can cite it (previously paths[0] was reported even when a
    # later path triggered the escalation).
    risk_order = {"safe": 0, "unknown": 1, "high": 2, "critical": 3}
    paths = _extract_paths(action)
    max_path_risk = "safe"
    riskiest_path = paths[0] if paths else "unknown"
    for path in paths:
        path_risk = _classify_path(path)
        if risk_order.get(path_risk, 0) > risk_order.get(max_path_risk, 0):
            max_path_risk = path_risk
            riskiest_path = path

    reasons = []

    # Path-based tier adjustment.
    if max_path_risk == "critical":
        base_tier = RiskTier.CRITICAL
        reasons.append(f"Critical path detected: {riskiest_path}")
    elif max_path_risk == "high":
        if base_tier.value < RiskTier.HIGH.value:
            base_tier = RiskTier.HIGH
        reasons.append(f"High-risk path: {riskiest_path}")
    elif max_path_risk == "safe":
        # Downgrade if all paths are safe.
        if base_tier.value > RiskTier.MEDIUM.value:
            base_tier = RiskTier.MEDIUM
            reasons.append("Safe path context — risk downgraded")

    # These factors feed reasons/context_factors below; they do not change
    # the tier itself. _track_command also records this action (side effect).
    modifier = _get_command_risk_modifier(action)
    scope = _assess_scope(action)
    recency = _track_command(action)

    # Dry-run overrides everything. Token match fixes the old substring
    # test: '"-n " in action' fired inside unrelated arguments and missed
    # a trailing "-n".
    if "--dry-run" in parts or "-n" in parts:
        return RiskResult(
            tier=RiskTier.SAFE,
            confidence=0.95,
            reasons=["dry-run mode — no actual changes"],
            context_factors={"dry_run": True},
        )

    # Confidence reflects how well the path context was understood.
    confidence = 0.8  # Base confidence

    if max_path_risk == "safe":
        confidence = 0.9
    elif max_path_risk == "unknown":
        confidence = 0.6
    elif max_path_risk == "critical":
        confidence = 0.95

    # Explanations for notable factors.
    if modifier > 1.5:
        reasons.append(f"Force/recursive flags (modifier: {modifier:.1f}x)")
    if scope > 1.5:
        reasons.append(f"Wide scope (wildcards/multiple targets, {scope:.1f}x)")
    if recency > 1.2:
        reasons.append(f"Repeated command pattern ({recency:.1f}x escalation)")

    if not reasons:
        reasons.append(f"Command '{cmd_name}' classified as {base_tier.name}")

    return RiskResult(
        tier=base_tier,
        confidence=round(confidence, 2),
        reasons=reasons,
        context_factors={
            "path_risk": max_path_risk,
            "modifier": round(modifier, 2),
            "scope": round(scope, 2),
            "recency": round(recency, 2),
            "paths": paths,
        },
    )
Reference in New Issue
Block a user