Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
97c075f2fe |
@@ -1,302 +0,0 @@
|
||||
"""Self-Modifying Prompt Engine — agent learns from its own failures.
|
||||
|
||||
Analyzes session transcripts, identifies failure patterns, and generates
|
||||
prompt patches to prevent future failures.
|
||||
|
||||
The loop: fail → analyze → rewrite → retry → verify improvement.
|
||||
|
||||
Usage:
|
||||
from agent.self_modify import PromptLearner
|
||||
learner = PromptLearner()
|
||||
patches = learner.analyze_session(session_id)
|
||||
learner.apply_patches(patches)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
HERMES_HOME = Path(os.getenv("HERMES_HOME", Path.home() / ".hermes"))
|
||||
PATCHES_DIR = HERMES_HOME / "prompt_patches"
|
||||
ROLLBACK_DIR = HERMES_HOME / "prompt_rollback"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FailurePattern:
|
||||
"""A detected failure pattern in session transcripts."""
|
||||
pattern_type: str # retry_loop, timeout, error_hallucination, context_loss
|
||||
description: str
|
||||
frequency: int
|
||||
example_messages: List[str] = field(default_factory=list)
|
||||
suggested_fix: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class PromptPatch:
|
||||
"""A modification to the system prompt based on failure analysis."""
|
||||
id: str
|
||||
failure_type: str
|
||||
original_rule: str
|
||||
new_rule: str
|
||||
confidence: float
|
||||
applied_at: Optional[float] = None
|
||||
reverted: bool = False
|
||||
|
||||
|
||||
# Failure detection patterns
|
||||
FAILURE_SIGNALS = {
|
||||
"retry_loop": {
|
||||
"patterns": [
|
||||
r"(?i)retry(?:ing)?\s*(?:attempt|again)",
|
||||
r"(?i)failed.*retrying",
|
||||
r"(?i)error.*again",
|
||||
r"(?i)attempt\s+\d+\s*(?:of|/)\s*\d+",
|
||||
],
|
||||
"description": "Agent stuck in retry loop",
|
||||
},
|
||||
"timeout": {
|
||||
"patterns": [
|
||||
r"(?i)timed?\s*out",
|
||||
r"(?i)deadline\s+exceeded",
|
||||
r"(?i)took\s+(?:too\s+)?long",
|
||||
],
|
||||
"description": "Operation timed out",
|
||||
},
|
||||
"hallucination": {
|
||||
"patterns": [
|
||||
r"(?i)i\s+(?:don't|do\s+not)\s+(?:have|see|find)\s+(?:any|that|this)\s+(?:information|data|file)",
|
||||
r"(?i)the\s+file\s+doesn't\s+exist",
|
||||
r"(?i)i\s+(?:made|invented|fabricated)\s+(?:that\s+up|this)",
|
||||
],
|
||||
"description": "Agent hallucinated or fabricated information",
|
||||
},
|
||||
"context_loss": {
|
||||
"patterns": [
|
||||
r"(?i)i\s+(?:don't|do\s+not)\s+(?:remember|recall|know)\s+(?:what|where|when|how)",
|
||||
r"(?i)could\s+you\s+remind\s+me",
|
||||
r"(?i)what\s+were\s+we\s+(?:doing|working|talking)\s+(?:on|about)",
|
||||
],
|
||||
"description": "Agent lost context from earlier in conversation",
|
||||
},
|
||||
"tool_failure": {
|
||||
"patterns": [
|
||||
r"(?i)tool\s+(?:call|execution)\s+failed",
|
||||
r"(?i)command\s+not\s+found",
|
||||
r"(?i)permission\s+denied",
|
||||
r"(?i)no\s+such\s+file",
|
||||
],
|
||||
"description": "Tool execution failed",
|
||||
},
|
||||
}
|
||||
|
||||
# Prompt improvement templates
|
||||
PROMPT_FIXES = {
|
||||
"retry_loop": (
|
||||
"If an operation fails more than twice, stop retrying. "
|
||||
"Report the failure and ask the user for guidance. "
|
||||
"Do not enter retry loops — they waste tokens."
|
||||
),
|
||||
"timeout": (
|
||||
"For operations that may take long, set a timeout and report "
|
||||
"progress. If an operation takes more than 30 seconds, report "
|
||||
"what you've done so far and ask if you should continue."
|
||||
),
|
||||
"hallucination": (
|
||||
"If you cannot find information, say 'I don't know' or "
|
||||
"'I couldn't find that.' Never fabricate information. "
|
||||
"If a file doesn't exist, say so — don't guess its contents."
|
||||
),
|
||||
"context_loss": (
|
||||
"When you need context from earlier in the conversation, "
|
||||
"use session_search to find it. Don't ask the user to repeat themselves."
|
||||
),
|
||||
"tool_failure": (
|
||||
"If a tool fails, check the error message and try a different approach. "
|
||||
"Don't retry the exact same command — diagnose first."
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class PromptLearner:
|
||||
"""Analyze session transcripts and generate prompt improvements."""
|
||||
|
||||
def __init__(self):
|
||||
PATCHES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
ROLLBACK_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def analyze_session(self, session_data: dict) -> List[FailurePattern]:
|
||||
"""Analyze a session for failure patterns.
|
||||
|
||||
Args:
|
||||
session_data: Session dict with 'messages' list.
|
||||
|
||||
Returns:
|
||||
List of detected failure patterns.
|
||||
"""
|
||||
messages = session_data.get("messages", [])
|
||||
patterns_found: Dict[str, FailurePattern] = {}
|
||||
|
||||
for msg in messages:
|
||||
content = str(msg.get("content", ""))
|
||||
role = msg.get("role", "")
|
||||
|
||||
# Only analyze assistant messages and tool results
|
||||
if role not in ("assistant", "tool"):
|
||||
continue
|
||||
|
||||
for failure_type, config in FAILURE_SIGNALS.items():
|
||||
for pattern in config["patterns"]:
|
||||
if re.search(pattern, content):
|
||||
if failure_type not in patterns_found:
|
||||
patterns_found[failure_type] = FailurePattern(
|
||||
pattern_type=failure_type,
|
||||
description=config["description"],
|
||||
frequency=0,
|
||||
suggested_fix=PROMPT_FIXES.get(failure_type, ""),
|
||||
)
|
||||
patterns_found[failure_type].frequency += 1
|
||||
if len(patterns_found[failure_type].example_messages) < 3:
|
||||
patterns_found[failure_type].example_messages.append(
|
||||
content[:200]
|
||||
)
|
||||
break # One match per message per type is enough
|
||||
|
||||
return list(patterns_found.values())
|
||||
|
||||
def generate_patches(self, patterns: List[FailurePattern],
|
||||
min_confidence: float = 0.7) -> List[PromptPatch]:
|
||||
"""Generate prompt patches from failure patterns.
|
||||
|
||||
Args:
|
||||
patterns: Detected failure patterns.
|
||||
min_confidence: Minimum confidence to generate a patch.
|
||||
|
||||
Returns:
|
||||
List of prompt patches.
|
||||
"""
|
||||
patches = []
|
||||
for pattern in patterns:
|
||||
# Confidence based on frequency
|
||||
if pattern.frequency >= 3:
|
||||
confidence = 0.9
|
||||
elif pattern.frequency >= 2:
|
||||
confidence = 0.75
|
||||
else:
|
||||
confidence = 0.5
|
||||
|
||||
if confidence < min_confidence:
|
||||
continue
|
||||
|
||||
if not pattern.suggested_fix:
|
||||
continue
|
||||
|
||||
patch = PromptPatch(
|
||||
id=f"{pattern.pattern_type}-{int(time.time())}",
|
||||
failure_type=pattern.pattern_type,
|
||||
original_rule="(missing — no existing rule for this pattern)",
|
||||
new_rule=pattern.suggested_fix,
|
||||
confidence=confidence,
|
||||
)
|
||||
patches.append(patch)
|
||||
|
||||
return patches
|
||||
|
||||
def apply_patches(self, patches: List[PromptPatch],
|
||||
prompt_path: Optional[str] = None) -> int:
|
||||
"""Apply patches to the system prompt.
|
||||
|
||||
Args:
|
||||
patches: Patches to apply.
|
||||
prompt_path: Path to prompt file (default: ~/.hermes/system_prompt.md)
|
||||
|
||||
Returns:
|
||||
Number of patches applied.
|
||||
"""
|
||||
if prompt_path is None:
|
||||
prompt_path = str(HERMES_HOME / "system_prompt.md")
|
||||
|
||||
prompt_file = Path(prompt_path)
|
||||
|
||||
# Backup current prompt
|
||||
if prompt_file.exists():
|
||||
backup = ROLLBACK_DIR / f"{prompt_file.name}.{int(time.time())}.bak"
|
||||
backup.write_text(prompt_file.read_text())
|
||||
|
||||
# Read current prompt
|
||||
current = prompt_file.read_text() if prompt_file.exists() else ""
|
||||
|
||||
# Apply patches
|
||||
applied = 0
|
||||
additions = []
|
||||
for patch in patches:
|
||||
if patch.new_rule not in current:
|
||||
additions.append(f"\n## Auto-learned: {patch.failure_type}\n{patch.new_rule}")
|
||||
patch.applied_at = time.time()
|
||||
applied += 1
|
||||
|
||||
if additions:
|
||||
new_content = current + "\n".join(additions)
|
||||
prompt_file.write_text(new_content)
|
||||
|
||||
# Log patches
|
||||
patches_file = PATCHES_DIR / f"patches-{int(time.time())}.json"
|
||||
with open(patches_file, "w") as f:
|
||||
json.dump([p.__dict__ for p in patches], f, indent=2, default=str)
|
||||
|
||||
logger.info("Applied %d prompt patches", applied)
|
||||
return applied
|
||||
|
||||
def rollback_last(self, prompt_path: Optional[str] = None) -> bool:
|
||||
"""Rollback to the most recent backup.
|
||||
|
||||
Args:
|
||||
prompt_path: Path to prompt file.
|
||||
|
||||
Returns:
|
||||
True if rollback succeeded.
|
||||
"""
|
||||
if prompt_path is None:
|
||||
prompt_path = str(HERMES_HOME / "system_prompt.md")
|
||||
|
||||
backups = sorted(ROLLBACK_DIR.glob("*.bak"), reverse=True)
|
||||
if not backups:
|
||||
logger.warning("No backups to rollback to")
|
||||
return False
|
||||
|
||||
latest = backups[0]
|
||||
Path(prompt_path).write_text(latest.read_text())
|
||||
logger.info("Rolled back to %s", latest.name)
|
||||
return True
|
||||
|
||||
def learn_from_session(self, session_data: dict) -> Dict[str, Any]:
|
||||
"""Full learning cycle: analyze → patch → apply.
|
||||
|
||||
Args:
|
||||
session_data: Session dict.
|
||||
|
||||
Returns:
|
||||
Summary of what was learned and applied.
|
||||
"""
|
||||
patterns = self.analyze_session(session_data)
|
||||
patches = self.generate_patches(patterns)
|
||||
applied = self.apply_patches(patches)
|
||||
|
||||
return {
|
||||
"patterns_detected": len(patterns),
|
||||
"patches_generated": len(patches),
|
||||
"patches_applied": applied,
|
||||
"patterns": [
|
||||
{"type": p.pattern_type, "frequency": p.frequency, "description": p.description}
|
||||
for p in patterns
|
||||
],
|
||||
}
|
||||
@@ -1,265 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Hermes MCP Server — expose hermes-agent tools to fleet peers.
|
||||
|
||||
Runs as a standalone MCP server that other agents can connect to
|
||||
and invoke hermes tools remotely.
|
||||
|
||||
Safe tools exposed:
|
||||
- terminal (safe commands only)
|
||||
- file_read, file_search
|
||||
- web_search, web_extract
|
||||
- session_search
|
||||
|
||||
NOT exposed (internal tools):
|
||||
- approval, delegate, memory, config
|
||||
|
||||
Usage:
|
||||
python -m tools.mcp_server --port 8081
|
||||
hermes mcp-server --port 8081
|
||||
python scripts/mcp_server.py --port 8081 --auth-key SECRET
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools safe to expose to other agents
|
||||
SAFE_TOOLS = {
|
||||
"terminal": {
|
||||
"name": "terminal",
|
||||
"description": "Execute safe shell commands. Dangerous commands are blocked.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string", "description": "Shell command to execute"},
|
||||
},
|
||||
"required": ["command"],
|
||||
},
|
||||
},
|
||||
"file_read": {
|
||||
"name": "file_read",
|
||||
"description": "Read the contents of a file.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "File path to read"},
|
||||
"offset": {"type": "integer", "description": "Start line", "default": 1},
|
||||
"limit": {"type": "integer", "description": "Max lines", "default": 200},
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
"file_search": {
|
||||
"name": "file_search",
|
||||
"description": "Search file contents using regex.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {"type": "string", "description": "Regex pattern"},
|
||||
"path": {"type": "string", "description": "Directory to search", "default": "."},
|
||||
},
|
||||
"required": ["pattern"],
|
||||
},
|
||||
},
|
||||
"web_search": {
|
||||
"name": "web_search",
|
||||
"description": "Search the web for information.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
"session_search": {
|
||||
"name": "session_search",
|
||||
"description": "Search past conversation sessions.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string", "description": "Search query"},
|
||||
"limit": {"type": "integer", "description": "Max results", "default": 3},
|
||||
},
|
||||
"required": ["query"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# Tools explicitly blocked
|
||||
BLOCKED_TOOLS = {
|
||||
"approval", "delegate", "memory", "config", "skill_install",
|
||||
"mcp_tool", "cronjob", "tts", "send_message",
|
||||
}
|
||||
|
||||
|
||||
class MCPServer:
|
||||
"""Simple MCP-compatible server for exposing hermes tools."""
|
||||
|
||||
def __init__(self, host: str = "127.0.0.1", port: int = 8081,
|
||||
auth_key: Optional[str] = None):
|
||||
self._host = host
|
||||
self._port = port
|
||||
self._auth_key = auth_key or os.getenv("MCP_AUTH_KEY", "")
|
||||
|
||||
async def handle_tools_list(self, request: dict) -> dict:
|
||||
"""Return available tools."""
|
||||
tools = list(SAFE_TOOLS.values())
|
||||
return {"tools": tools}
|
||||
|
||||
async def handle_tools_call(self, request: dict) -> dict:
|
||||
"""Execute a tool call."""
|
||||
tool_name = request.get("name", "")
|
||||
arguments = request.get("arguments", {})
|
||||
|
||||
if tool_name in BLOCKED_TOOLS:
|
||||
return {"error": f"Tool '{tool_name}' is not exposed via MCP"}
|
||||
if tool_name not in SAFE_TOOLS:
|
||||
return {"error": f"Unknown tool: {tool_name}"}
|
||||
|
||||
try:
|
||||
result = await self._execute_tool(tool_name, arguments)
|
||||
return {"content": [{"type": "text", "text": str(result)}]}
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _execute_tool(self, tool_name: str, arguments: dict) -> str:
|
||||
"""Execute a tool and return result."""
|
||||
if tool_name == "terminal":
|
||||
import subprocess
|
||||
cmd = arguments.get("command", "")
|
||||
# Block dangerous commands
|
||||
from tools.approval import detect_dangerous_command
|
||||
is_dangerous, _, desc = detect_dangerous_command(cmd)
|
||||
if is_dangerous:
|
||||
return f"BLOCKED: Dangerous command detected ({desc}). This tool only executes safe commands."
|
||||
result = subprocess.run(cmd, shell=True, capture_output=True, text=True, timeout=30)
|
||||
return result.stdout or result.stderr or "(no output)"
|
||||
|
||||
elif tool_name == "file_read":
|
||||
path = arguments.get("path", "")
|
||||
offset = arguments.get("offset", 1)
|
||||
limit = arguments.get("limit", 200)
|
||||
with open(path) as f:
|
||||
lines = f.readlines()
|
||||
return "".join(lines[offset-1:offset-1+limit])
|
||||
|
||||
elif tool_name == "file_search":
|
||||
import re
|
||||
pattern = arguments.get("pattern", "")
|
||||
path = arguments.get("path", ".")
|
||||
results = []
|
||||
for p in Path(path).rglob("*.py"):
|
||||
try:
|
||||
content = p.read_text()
|
||||
for i, line in enumerate(content.split("\n"), 1):
|
||||
if re.search(pattern, line, re.IGNORECASE):
|
||||
results.append(f"{p}:{i}: {line.strip()}")
|
||||
if len(results) >= 20:
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if len(results) >= 20:
|
||||
break
|
||||
return "\n".join(results) or "No matches found"
|
||||
|
||||
elif tool_name == "web_search":
|
||||
try:
|
||||
from tools.web_tools import web_search
|
||||
return web_search(arguments.get("query", ""))
|
||||
except ImportError:
|
||||
return "Web search not available"
|
||||
|
||||
elif tool_name == "session_search":
|
||||
try:
|
||||
from tools.session_search_tool import session_search
|
||||
return session_search(
|
||||
query=arguments.get("query", ""),
|
||||
limit=arguments.get("limit", 3),
|
||||
)
|
||||
except ImportError:
|
||||
return "Session search not available"
|
||||
|
||||
return f"Tool {tool_name} not implemented"
|
||||
|
||||
async def start_http(self):
|
||||
"""Start HTTP server for MCP endpoints."""
|
||||
try:
|
||||
from aiohttp import web
|
||||
except ImportError:
|
||||
logger.error("aiohttp required: pip install aiohttp")
|
||||
return
|
||||
|
||||
app = web.Application()
|
||||
|
||||
async def handle_tools_list_route(request):
|
||||
if self._auth_key:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if auth != f"Bearer {self._auth_key}":
|
||||
return web.json_response({"error": "Unauthorized"}, status=401)
|
||||
result = await self.handle_tools_list({})
|
||||
return web.json_response(result)
|
||||
|
||||
async def handle_tools_call_route(request):
|
||||
if self._auth_key:
|
||||
auth = request.headers.get("Authorization", "")
|
||||
if auth != f"Bearer {self._auth_key}":
|
||||
return web.json_response({"error": "Unauthorized"}, status=401)
|
||||
body = await request.json()
|
||||
result = await self.handle_tools_call(body)
|
||||
return web.json_response(result)
|
||||
|
||||
async def handle_health(request):
|
||||
return web.json_response({"status": "ok", "tools": len(SAFE_TOOLS)})
|
||||
|
||||
app.router.add_get("/mcp/tools", handle_tools_list_route)
|
||||
app.router.add_post("/mcp/tools/call", handle_tools_call_route)
|
||||
app.router.add_get("/health", handle_health)
|
||||
|
||||
runner = web.AppRunner(app)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, self._host, self._port)
|
||||
await site.start()
|
||||
logger.info("MCP server on http://%s:%s", self._host, self._port)
|
||||
logger.info("Tools: %s", ", ".join(SAFE_TOOLS.keys()))
|
||||
if self._auth_key:
|
||||
logger.info("Auth: Bearer token required")
|
||||
else:
|
||||
logger.warning("Auth: No MCP_AUTH_KEY set — server is open")
|
||||
|
||||
try:
|
||||
await asyncio.Event().wait()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Hermes MCP Server")
|
||||
parser.add_argument("--host", default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=8081)
|
||||
parser.add_argument("--auth-key", default=None, help="Bearer token for auth")
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s [%(name)s] %(levelname)s: %(message)s")
|
||||
|
||||
server = MCPServer(host=args.host, port=args.port, auth_key=args.auth_key)
|
||||
print(f"Starting MCP server on http://{args.host}:{args.port}")
|
||||
print(f"Exposed tools: {', '.join(SAFE_TOOLS.keys())}")
|
||||
asyncio.run(server.start_http())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
122
tests/test_credential_redaction.py
Normal file
122
tests/test_credential_redaction.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Tests for credential redaction — Issue #839."""
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from tools.credential_redaction import (
|
||||
redact_credentials, should_auto_mask, mask_config_values,
|
||||
redact_tool_output, RedactionResult
|
||||
)
|
||||
|
||||
|
||||
class TestRedactCredentials:
|
||||
def test_openai_key(self):
|
||||
text = "API key: sk-abc123def456ghi789jkl012mno345pqr678stu901vwx"
|
||||
result = redact_credentials(text)
|
||||
assert result.was_redacted
|
||||
assert "sk-abc" not in result.text
|
||||
assert "[REDACTED" in result.text
|
||||
|
||||
def test_github_pat(self):
|
||||
text = "token: ghp_1234567890abcdefghijklmnopqrstuvwxyz"
|
||||
result = redact_credentials(text)
|
||||
assert result.was_redacted
|
||||
assert "ghp_" not in result.text
|
||||
|
||||
def test_bearer_token(self):
|
||||
text = "Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9"
|
||||
result = redact_credentials(text)
|
||||
assert result.was_redacted
|
||||
assert "Bearer eyJ" not in result.text
|
||||
|
||||
def test_password_assignment(self):
|
||||
text = 'password: "supersecret123"'
|
||||
result = redact_credentials(text)
|
||||
assert result.was_redacted
|
||||
|
||||
def test_clean_text(self):
|
||||
text = "Hello world, no credentials here"
|
||||
result = redact_credentials(text)
|
||||
assert not result.was_redacted
|
||||
assert result.text == text
|
||||
|
||||
def test_empty_text(self):
|
||||
result = redact_credentials("")
|
||||
assert not result.was_redacted
|
||||
|
||||
|
||||
class TestShouldAutoMask:
|
||||
def test_env_file(self):
|
||||
assert should_auto_mask(".env") == True
|
||||
|
||||
def test_config_file(self):
|
||||
assert should_auto_mask("config.yaml") == True
|
||||
|
||||
def test_token_file(self):
|
||||
assert should_auto_mask("gitea_token") == True
|
||||
|
||||
def test_normal_file(self):
|
||||
assert should_auto_mask("readme.md") == False
|
||||
|
||||
|
||||
class TestMaskConfigValues:
|
||||
def test_env_api_key(self):
|
||||
text = "API_KEY=sk-abc123def456"
|
||||
result = mask_config_values(text)
|
||||
assert "sk-abc" not in result
|
||||
assert "[REDACTED]" in result
|
||||
|
||||
def test_yaml_token(self):
|
||||
text = 'token: "ghp_1234567890"'
|
||||
result = mask_config_values(text)
|
||||
assert "ghp_" not in result
|
||||
assert "[REDACTED]" in result
|
||||
|
||||
def test_preserves_structure(self):
|
||||
text = "API_KEY=secret\nOTHER=value"
|
||||
result = mask_config_values(text)
|
||||
assert "OTHER=value" in result # Non-credential preserved
|
||||
|
||||
|
||||
class TestRedactToolOutput:
|
||||
def test_string_output(self):
|
||||
output = "Result: sk-abc123def456ghi789jkl012mno345pqr678stu901vwx"
|
||||
redacted, notice = redact_tool_output("file_read", output)
|
||||
assert "sk-abc123" not in redacted
|
||||
assert notice is not None
|
||||
|
||||
def test_dict_output(self):
|
||||
output = {"content": "token: ghp_1234567890abcdefghijklmnopqrstuvwxyz"}
|
||||
redacted, notice = redact_tool_output("file_read", output)
|
||||
assert "ghp_" not in redacted["content"]
|
||||
|
||||
def test_clean_output(self):
|
||||
output = "No credentials here"
|
||||
redacted, notice = redact_tool_output("file_read", output)
|
||||
assert redacted == output
|
||||
assert notice is None
|
||||
|
||||
|
||||
class TestRedactionResult:
|
||||
def test_notice_singular(self):
|
||||
result = RedactionResult("redacted", "original", [{"pattern_name": "test"}])
|
||||
assert "1 credential pattern" in result.notice()
|
||||
|
||||
def test_notice_plural(self):
|
||||
result = RedactionResult("redacted", "original", [
|
||||
{"pattern_name": "test1"},
|
||||
{"pattern_name": "test2"},
|
||||
])
|
||||
assert "2 credential patterns" in result.notice()
|
||||
|
||||
def test_to_dict(self):
|
||||
result = RedactionResult("redacted", "original", [{"pattern_name": "test"}])
|
||||
d = result.to_dict()
|
||||
assert d["redacted"] == True
|
||||
assert d["count"] == 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import pytest
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -201,31 +201,8 @@ def _get_command_timeout() -> int:
|
||||
|
||||
|
||||
def _get_vision_model() -> Optional[str]:
|
||||
"""Model for browser_vision (screenshot analysis — multimodal).
|
||||
|
||||
Priority:
|
||||
1. AUXILIARY_VISION_MODEL env var (explicit override)
|
||||
2. Gemma 4 (native multimodal, no model switching)
|
||||
3. Ollama local vision models
|
||||
4. None (fallback to text-only snapshot)
|
||||
"""
|
||||
# Explicit override always wins
|
||||
explicit = os.getenv("AUXILIARY_VISION_MODEL", "").strip()
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
# Prefer Gemma 4 (native multimodal — no separate vision model needed)
|
||||
gemma = os.getenv("GEMMA_VISION_MODEL", "").strip()
|
||||
if gemma:
|
||||
return gemma
|
||||
|
||||
# Check for Ollama vision models
|
||||
ollama_vision = os.getenv("OLLAMA_VISION_MODEL", "").strip()
|
||||
if ollama_vision:
|
||||
return ollama_vision
|
||||
|
||||
# Default: None (text-only fallback)
|
||||
return None
|
||||
"""Model for browser_vision (screenshot analysis — multimodal)."""
|
||||
return os.getenv("AUXILIARY_VISION_MODEL", "").strip() or None
|
||||
|
||||
|
||||
def _get_extraction_model() -> Optional[str]:
|
||||
|
||||
269
tools/credential_redaction.py
Normal file
269
tools/credential_redaction.py
Normal file
@@ -0,0 +1,269 @@
|
||||
"""Credential Redaction — Poka-yoke for tool outputs.
|
||||
|
||||
Blocks silent credential exposure by redacting API keys, tokens, and
|
||||
passwords from tool outputs before they enter agent context.
|
||||
|
||||
Issue #839: Poka-yoke: Block silent credential exposure in tool outputs
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Audit log path
|
||||
_AUDIT_DIR = Path.home() / ".hermes" / "audit"
|
||||
_AUDIT_LOG = _AUDIT_DIR / "redactions.jsonl"
|
||||
|
||||
# Credential patterns — order matters (most specific first)
|
||||
_CREDENTIAL_PATTERNS = [
|
||||
# API keys
|
||||
(r'sk-[a-zA-Z0-9]{20,}', '[REDACTED: OpenAI-style API key]'),
|
||||
(r'sk-ant-[a-zA-Z0-9-]{20,}', '[REDACTED: Anthropic API key]'),
|
||||
(r'ghp_[a-zA-Z0-9]{36}', '[REDACTED: GitHub PAT]'),
|
||||
(r'gho_[a-zA-Z0-9]{36}', '[REDACTED: GitHub OAuth token]'),
|
||||
(r'github_pat_[a-zA-Z0-9_]{82}', '[REDACTED: GitHub fine-grained PAT]'),
|
||||
(r'glpat-[a-zA-Z0-9-]{20,}', '[REDACTED: GitLab PAT]'),
|
||||
(r'syt_[a-zA-Z0-9_-]{40,}', '[REDACTED: Matrix access token]'),
|
||||
(r'xoxb-[0-9]{10,}-[a-zA-Z0-9]{20,}', '[REDACTED: Slack bot token]'),
|
||||
(r'xoxp-[0-9]{10,}-[a-zA-Z0-9]{20,}', '[REDACTED: Slack user token]'),
|
||||
|
||||
# Bearer tokens
|
||||
(r'Bearer\s+[a-zA-Z0-9_.-]{20,}', '[REDACTED: Bearer token]'),
|
||||
|
||||
# Generic tokens/passwords in assignments
|
||||
(r'(?:token|api_key|api_key|secret|password|passwd|pwd)\s*[:=]\s*["\']?([a-zA-Z0-9_.-]{8,})["\']?', '[REDACTED: credential]'),
|
||||
|
||||
# Environment variable assignments
|
||||
(r'(?:export\s+)?(?:TOKEN|KEY|SECRET|PASSWORD|API_KEY)\s*=\s*["\']?([a-zA-Z0-9_.-]{8,})["\']?', '[REDACTED: env credential]'),
|
||||
|
||||
# Base64 encoded credentials (high entropy strings)
|
||||
(r'(?:authorization|auth)\s*[:=]\s*(?:basic|bearer)\s+[a-zA-Z0-9+/=]{20,}', '[REDACTED: auth header]'),
|
||||
|
||||
# AWS credentials
|
||||
(r'AKIA[0-9A-Z]{16}', '[REDACTED: AWS access key]'),
|
||||
(r'(?<![A-Z0-9])[A-Za-z0-9/+=]{40}(?![A-Z0-9])', None), # Only match near context
|
||||
|
||||
# Private keys
|
||||
(r'-----BEGIN (?:RSA |EC |OPENSSH )?PRIVATE KEY-----', '[REDACTED: private key block]'),
|
||||
]
|
||||
|
||||
|
||||
class RedactionResult:
|
||||
"""Result of credential redaction."""
|
||||
|
||||
def __init__(self, text: str, original: str, redactions: List[Dict[str, Any]]):
|
||||
self.text = text
|
||||
self.original = original
|
||||
self.redactions = redactions
|
||||
|
||||
@property
|
||||
def was_redacted(self) -> bool:
|
||||
return len(self.redactions) > 0
|
||||
|
||||
@property
|
||||
def count(self) -> int:
|
||||
return len(self.redactions)
|
||||
|
||||
def notice(self) -> str:
|
||||
"""Generate compact redaction notice."""
|
||||
if not self.was_redacted:
|
||||
return ""
|
||||
return f"[REDACTED: {self.count} credential pattern{'s' if self.count > 1 else ''} found]"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"redacted": self.was_redacted,
|
||||
"count": self.count,
|
||||
"notice": self.notice(),
|
||||
"patterns": [r["pattern_name"] for r in self.redactions],
|
||||
}
|
||||
|
||||
|
||||
def redact_credentials(text: str, source: str = "unknown") -> RedactionResult:
|
||||
"""Redact credentials from text.
|
||||
|
||||
Args:
|
||||
text: Text to redact
|
||||
source: Source identifier for audit logging
|
||||
|
||||
Returns:
|
||||
RedactionResult with redacted text and metadata
|
||||
"""
|
||||
if not text:
|
||||
return RedactionResult(text, text, [])
|
||||
|
||||
redactions = []
|
||||
result = text
|
||||
|
||||
for pattern, replacement in _CREDENTIAL_PATTERNS:
|
||||
if replacement is None:
|
||||
continue # Skip conditional patterns
|
||||
|
||||
matches = list(re.finditer(pattern, result, re.IGNORECASE))
|
||||
for match in matches:
|
||||
redactions.append({
|
||||
"pattern_name": replacement,
|
||||
"position": match.start(),
|
||||
"length": len(match.group()),
|
||||
"source": source,
|
||||
"timestamp": time.time(),
|
||||
})
|
||||
|
||||
result = re.sub(pattern, replacement, result, flags=re.IGNORECASE)
|
||||
|
||||
redaction_result = RedactionResult(result, text, redactions)
|
||||
|
||||
# Log to audit trail
|
||||
if redaction_result.was_redacted:
|
||||
_log_redaction(redaction_result, source)
|
||||
|
||||
return redaction_result
|
||||
|
||||
|
||||
def _log_redaction(result: RedactionResult, source: str) -> None:
|
||||
"""Log redaction event to audit trail."""
|
||||
try:
|
||||
_AUDIT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
entry = {
|
||||
"timestamp": time.time(),
|
||||
"source": source,
|
||||
"count": result.count,
|
||||
"patterns": [r["pattern_name"] for r in result.redactions],
|
||||
}
|
||||
with open(_AUDIT_LOG, "a") as f:
|
||||
f.write(json.dumps(entry) + "\n")
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to log redaction: {e}")
|
||||
|
||||
|
||||
def should_auto_mask(file_path: str) -> bool:
|
||||
"""Check if file should have credentials auto-masked."""
|
||||
path_lower = file_path.lower()
|
||||
sensitive_patterns = [
|
||||
".env", "config", "token", "secret", "credential",
|
||||
"key", "auth", "password", ".pem", ".key",
|
||||
]
|
||||
return any(p in path_lower for p in sensitive_patterns)
|
||||
|
||||
|
||||
def mask_config_values(text: str) -> str:
|
||||
"""Mask credential values in config/env files while preserving structure.
|
||||
|
||||
Transforms:
|
||||
API_KEY=sk-abc123 → API_KEY=[REDACTED]
|
||||
token: "ghp_xyz" → token: "[REDACTED]"
|
||||
"""
|
||||
lines = text.split("\n")
|
||||
result = []
|
||||
|
||||
for line in lines:
|
||||
# Match KEY=VALUE patterns
|
||||
match = re.match(r'^(\s*(?:export\s+)?[A-Z_][A-Z0-9_]*)\s*=\s*(.*)', line)
|
||||
if match:
|
||||
key = match.group(1)
|
||||
value = match.group(2).strip()
|
||||
|
||||
# Check if key looks credential-like
|
||||
key_lower = key.lower()
|
||||
if any(p in key_lower for p in ["key", "token", "secret", "password", "auth"]):
|
||||
if value and not value.startswith("[REDACTED]"):
|
||||
# Preserve quotes
|
||||
if value.startswith('"') and value.endswith('"'):
|
||||
result.append(f'{key}="[REDACTED]"')
|
||||
elif value.startswith("'") and value.endswith("'"):
|
||||
result.append(f"{key}='[REDACTED]'")
|
||||
else:
|
||||
result.append(f"{key}=[REDACTED]")
|
||||
continue
|
||||
|
||||
# Match YAML-style key: value
|
||||
match = re.match(r'^(\s*[a-z_][a-z0-9_]*)\s*:\s*["\']?(.*?)["\']?\s*$', line)
|
||||
if match:
|
||||
key = match.group(1)
|
||||
value = match.group(2).strip()
|
||||
|
||||
key_lower = key.lower()
|
||||
if any(p in key_lower for p in ["key", "token", "secret", "password", "auth"]):
|
||||
if value and not value.startswith("[REDACTED]"):
|
||||
result.append(f'{key}: "[REDACTED]"')
|
||||
continue
|
||||
|
||||
result.append(line)
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def redact_tool_output(
|
||||
tool_name: str,
|
||||
output: Any,
|
||||
source: str = None,
|
||||
) -> Tuple[Any, Optional[str]]:
|
||||
"""Redact credentials from tool output.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool
|
||||
output: Tool output (string or dict)
|
||||
source: Source identifier (defaults to tool_name)
|
||||
|
||||
Returns:
|
||||
Tuple of (redacted_output, notice)
|
||||
"""
|
||||
source = source or tool_name
|
||||
|
||||
if isinstance(output, str):
|
||||
result = redact_credentials(output, source)
|
||||
if result.was_redacted:
|
||||
return result.text, result.notice()
|
||||
return output, None
|
||||
|
||||
if isinstance(output, dict):
|
||||
# Redact string values in dict
|
||||
redacted = {}
|
||||
notices = []
|
||||
for key, value in output.items():
|
||||
if isinstance(value, str):
|
||||
r, n = redact_tool_output(tool_name, value, f"{source}.{key}")
|
||||
redacted[key] = r
|
||||
if n:
|
||||
notices.append(n)
|
||||
else:
|
||||
redacted[key] = value
|
||||
|
||||
notice = "; ".join(notices) if notices else None
|
||||
return redacted, notice
|
||||
|
||||
# Non-string, non-dict: pass through
|
||||
return output, None
|
||||
|
||||
|
||||
def get_redaction_stats() -> Dict[str, Any]:
|
||||
"""Get redaction statistics from audit log."""
|
||||
stats = {
|
||||
"total_redactions": 0,
|
||||
"by_source": {},
|
||||
"by_pattern": {},
|
||||
}
|
||||
|
||||
if not _AUDIT_LOG.exists():
|
||||
return stats
|
||||
|
||||
try:
|
||||
with open(_AUDIT_LOG, "r") as f:
|
||||
for line in f:
|
||||
entry = json.loads(line.strip())
|
||||
stats["total_redactions"] += entry.get("count", 0)
|
||||
|
||||
source = entry.get("source", "unknown")
|
||||
stats["by_source"][source] = stats["by_source"].get(source, 0) + 1
|
||||
|
||||
for pattern in entry.get("patterns", []):
|
||||
stats["by_pattern"][pattern] = stats["by_pattern"].get(pattern, 0) + 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return stats
|
||||
Reference in New Issue
Block a user