Compare commits
1 Commits
burn/320-1
...
fix/582-sh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b87afe1ed0 |
192
cli.py
192
cli.py
@@ -3134,196 +3134,6 @@ class HermesCLI:
|
||||
print(f" Home: {display}")
|
||||
print()
|
||||
|
||||
def _handle_debug_command(self, command: str):
|
||||
"""Generate a debug report with system info and logs, upload to paste service."""
|
||||
import platform
|
||||
import sys
|
||||
import time as _time
|
||||
|
||||
# Parse optional lines argument
|
||||
parts = command.split(maxsplit=1)
|
||||
log_lines = 50
|
||||
if len(parts) > 1:
|
||||
try:
|
||||
log_lines = min(int(parts[1]), 500)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
_cprint(" Collecting debug info...")
|
||||
|
||||
# Collect system info
|
||||
lines = []
|
||||
lines.append("=== HERMES DEBUG REPORT ===")
|
||||
lines.append(f"Generated: {_time.strftime('%Y-%m-%d %H:%M:%S %z')}")
|
||||
lines.append("")
|
||||
|
||||
lines.append("--- System ---")
|
||||
lines.append(f"Python: {sys.version}")
|
||||
lines.append(f"Platform: {platform.platform()}")
|
||||
lines.append(f"Architecture: {platform.machine()}")
|
||||
lines.append(f"Hostname: {platform.node()}")
|
||||
lines.append("")
|
||||
|
||||
# Hermes info
|
||||
lines.append("--- Hermes ---")
|
||||
try:
|
||||
from hermes_constants import get_hermes_home, display_hermes_home
|
||||
lines.append(f"Home: {display_hermes_home()}")
|
||||
except Exception:
|
||||
lines.append("Home: unknown")
|
||||
|
||||
try:
|
||||
from hermes_constants import __version__
|
||||
lines.append(f"Version: {__version__}")
|
||||
except Exception:
|
||||
lines.append("Version: unknown")
|
||||
|
||||
lines.append(f"Profile: {getattr(self, '_profile_name', 'default')}")
|
||||
lines.append(f"Session: {self.session_id}")
|
||||
lines.append(f"Model: {self.model}")
|
||||
lines.append(f"Provider: {getattr(self, '_provider_name', 'unknown')}")
|
||||
|
||||
try:
|
||||
lines.append(f"Working dir: {os.getcwd()}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Config (redacted)
|
||||
lines.append("")
|
||||
lines.append("--- Config (redacted) ---")
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
config_path = get_hermes_home() / "config.yaml"
|
||||
if config_path.exists():
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
# Redact secrets
|
||||
for key in ("api_key", "token", "secret", "password"):
|
||||
if key in cfg:
|
||||
cfg[key] = "***REDACTED***"
|
||||
lines.append(yaml.dump(cfg, default_flow_style=False)[:2000])
|
||||
else:
|
||||
lines.append("(no config file found)")
|
||||
except Exception as e:
|
||||
lines.append(f"(error reading config: {e})")
|
||||
|
||||
# Recent logs
|
||||
lines.append("")
|
||||
lines.append(f"--- Recent Logs (last {log_lines} lines) ---")
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
log_dir = get_hermes_home() / "logs"
|
||||
if log_dir.exists():
|
||||
for log_file in sorted(log_dir.glob("*.log")):
|
||||
try:
|
||||
content = log_file.read_text(encoding="utf-8", errors="replace")
|
||||
tail = content.strip().split("\n")[-log_lines:]
|
||||
if tail:
|
||||
lines.append(f"\n[{log_file.name}]")
|
||||
lines.extend(tail)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
lines.append("(no logs directory)")
|
||||
except Exception:
|
||||
lines.append("(error reading logs)")
|
||||
|
||||
# Tool info
|
||||
lines.append("")
|
||||
lines.append("--- Enabled Toolsets ---")
|
||||
try:
|
||||
lines.append(", ".join(self.enabled_toolsets) if self.enabled_toolsets else "(none)")
|
||||
except Exception:
|
||||
lines.append("(unknown)")
|
||||
|
||||
report = "\n".join(lines)
|
||||
report_size = len(report)
|
||||
|
||||
# Try to upload to paste services
|
||||
paste_url = None
|
||||
services = [
|
||||
("dpaste", _upload_dpaste),
|
||||
("0x0.st", _upload_0x0st),
|
||||
]
|
||||
|
||||
for name, uploader in services:
|
||||
try:
|
||||
url = uploader(report)
|
||||
if url:
|
||||
paste_url = url
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
print()
|
||||
if paste_url:
|
||||
_cprint(f" Debug report uploaded: {paste_url}")
|
||||
_cprint(f" Size: {report_size} bytes, {len(lines)} lines")
|
||||
else:
|
||||
# Fallback: save locally
|
||||
try:
|
||||
from hermes_constants import get_hermes_home
|
||||
debug_path = get_hermes_home() / "debug-report.txt"
|
||||
debug_path.write_text(report, encoding="utf-8")
|
||||
_cprint(f" Paste services unavailable. Report saved to: {debug_path}")
|
||||
_cprint(f" Size: {report_size} bytes, {len(lines)} lines")
|
||||
except Exception as e:
|
||||
_cprint(f" Failed to save report: {e}")
|
||||
_cprint(f" Report ({report_size} bytes):")
|
||||
print(report)
|
||||
print()
|
||||
|
||||
|
||||
def _upload_dpaste(content: str) -> str | None:
|
||||
"""Upload content to dpaste.org. Returns URL or None."""
|
||||
import urllib.request
|
||||
import urllib.parse
|
||||
data = urllib.parse.urlencode({
|
||||
"content": content,
|
||||
"syntax": "text",
|
||||
"expiry_days": 7,
|
||||
}).encode()
|
||||
req = urllib.request.Request(
|
||||
"https://dpaste.org/api/",
|
||||
data=data,
|
||||
headers={"User-Agent": "hermes-agent/debug"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
url = resp.read().decode().strip()
|
||||
if url.startswith("http"):
|
||||
return url
|
||||
return None
|
||||
|
||||
|
||||
def _upload_0x0st(content: str) -> str | None:
|
||||
"""Upload content to 0x0.st. Returns URL or None."""
|
||||
import urllib.request
|
||||
import io
|
||||
# 0x0.st expects multipart form with a file field
|
||||
boundary = "----HermesDebugBoundary"
|
||||
body = (
|
||||
f"--{boundary}\r\n"
|
||||
f'Content-Disposition: form-data; name="file"; filename="debug.txt"\r\n'
|
||||
f"Content-Type: text/plain\r\n\r\n"
|
||||
f"{content}\r\n"
|
||||
f"--{boundary}--\r\n"
|
||||
).encode()
|
||||
req = urllib.request.Request(
|
||||
"https://0x0.st",
|
||||
data=body,
|
||||
headers={
|
||||
"Content-Type": f"multipart/form-data; boundary={boundary}",
|
||||
"User-Agent": "hermes-agent/debug",
|
||||
},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=10) as resp:
|
||||
url = resp.read().decode().strip()
|
||||
if url.startswith("http"):
|
||||
return url
|
||||
return None
|
||||
|
||||
|
||||
def show_config(self):
|
||||
"""Display current configuration with kawaii ASCII art."""
|
||||
# Get terminal config from environment (which was set from cli-config.yaml)
|
||||
@@ -4511,8 +4321,6 @@ def _upload_0x0st(content: str) -> str | None:
|
||||
self.show_help()
|
||||
elif canonical == "profile":
|
||||
self._handle_profile_command()
|
||||
elif canonical == "debug":
|
||||
self._handle_debug_command(cmd_original)
|
||||
elif canonical == "tools":
|
||||
self._handle_tools_command(cmd_original)
|
||||
elif canonical == "toolsets":
|
||||
|
||||
@@ -456,6 +456,71 @@ def _coerce_boolean(value: str):
|
||||
return value
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SHIELD: scan tool call arguments for indirect injection payloads
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Tools whose arguments are high-risk for injection
|
||||
_SHIELD_SCAN_TOOLS = frozenset({
|
||||
"terminal", "execute_code", "write_file", "patch",
|
||||
"browser_navigate", "browser_click", "browser_type",
|
||||
})
|
||||
|
||||
# Arguments to scan per tool
|
||||
_SHIELD_ARG_MAP = {
|
||||
"terminal": ("command",),
|
||||
"execute_code": ("code",),
|
||||
"write_file": ("content",),
|
||||
"patch": ("new_string",),
|
||||
"browser_navigate": ("url",),
|
||||
"browser_click": (),
|
||||
"browser_type": ("text",),
|
||||
}
|
||||
|
||||
|
||||
def _shield_scan_tool_args(function_name: str, function_args: Dict[str, Any]) -> None:
|
||||
"""Scan tool call arguments for injection payloads.
|
||||
|
||||
Raises ValueError if a threat is detected in tool arguments.
|
||||
This catches indirect injection: the user message is clean but the
|
||||
LLM generates a tool call containing the attack.
|
||||
"""
|
||||
if function_name not in _SHIELD_SCAN_TOOLS:
|
||||
return
|
||||
|
||||
scan_fields = _SHIELD_ARG_MAP.get(function_name, ())
|
||||
if not scan_fields:
|
||||
return
|
||||
|
||||
try:
|
||||
from tools.shield.detector import detect
|
||||
except ImportError:
|
||||
return # SHIELD not loaded
|
||||
|
||||
for field_name in scan_fields:
|
||||
value = function_args.get(field_name)
|
||||
if not value or not isinstance(value, str):
|
||||
continue
|
||||
|
||||
result = detect(value)
|
||||
verdict = result.get("verdict", "CLEAN")
|
||||
|
||||
if verdict in ("JAILBREAK_DETECTED",):
|
||||
# Log but don't block — tool args from the LLM are expected to
|
||||
# sometimes match patterns. Instead, inject a warning.
|
||||
import logging
|
||||
logging.getLogger(__name__).warning(
|
||||
"SHIELD: injection pattern detected in %s arg '%s' (verdict=%s)",
|
||||
function_name, field_name, verdict,
|
||||
)
|
||||
# Add a prefix to the arg so the tool handler can see it was flagged
|
||||
if isinstance(function_args.get(field_name), str):
|
||||
function_args[field_name] = (
|
||||
f"[SHIELD-WARNING: injection pattern detected] "
|
||||
+ function_args[field_name]
|
||||
)
|
||||
|
||||
|
||||
def handle_function_call(
|
||||
function_name: str,
|
||||
function_args: Dict[str, Any],
|
||||
@@ -484,6 +549,12 @@ def handle_function_call(
|
||||
# Coerce string arguments to their schema-declared types (e.g. "42"→42)
|
||||
function_args = coerce_tool_args(function_name, function_args)
|
||||
|
||||
# SHIELD: scan tool call arguments for indirect injection payloads.
|
||||
# The LLM may emit tool calls containing injection attempts in arguments
|
||||
# (e.g. terminal commands with "ignore all rules"). Scan high-risk tools.
|
||||
# (Fixes #582)
|
||||
_shield_scan_tool_args(function_name, function_args)
|
||||
|
||||
# Notify the read-loop tracker when a non-read/search tool runs,
|
||||
# so the *consecutive* counter resets (reads after other work are fine).
|
||||
if function_name not in _READ_SEARCH_TOOLS:
|
||||
|
||||
110
tests/test_shield_tool_args.py
Normal file
110
tests/test_shield_tool_args.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tests for SHIELD tool argument scanning (fix #582)."""
|
||||
|
||||
import sys
|
||||
import types
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
def _make_shield_mock():
|
||||
"""Create a mock shield detector module."""
|
||||
mock_module = types.ModuleType("tools.shield")
|
||||
mock_detector = types.ModuleType("tools.shield.detector")
|
||||
mock_detector.detect = MagicMock(return_value={"verdict": "CLEAN"})
|
||||
mock_module.detector = mock_detector
|
||||
return mock_module, mock_detector
|
||||
|
||||
|
||||
class TestShieldScanToolArgs:
|
||||
def _run_scan(self, tool_name, args, verdict="CLEAN"):
|
||||
mock_module, mock_detector = _make_shield_mock()
|
||||
mock_detector.detect.return_value = {"verdict": verdict}
|
||||
|
||||
with patch.dict(sys.modules, {
|
||||
"tools.shield": mock_module,
|
||||
"tools.shield.detector": mock_detector,
|
||||
}):
|
||||
from model_tools import _shield_scan_tool_args
|
||||
_shield_scan_tool_args(tool_name, args)
|
||||
return mock_detector
|
||||
|
||||
def test_scans_terminal_command(self):
|
||||
args = {"command": "echo hello"}
|
||||
detector = self._run_scan("terminal", args)
|
||||
detector.detect.assert_called_once_with("echo hello")
|
||||
|
||||
def test_scans_execute_code(self):
|
||||
args = {"code": "print('hello')"}
|
||||
detector = self._run_scan("execute_code", args)
|
||||
detector.detect.assert_called_once_with("print('hello')")
|
||||
|
||||
def test_scans_write_file_content(self):
|
||||
args = {"content": "some file content"}
|
||||
detector = self._run_scan("write_file", args)
|
||||
detector.detect.assert_called_once_with("some file content")
|
||||
|
||||
def test_skips_non_scanned_tools(self):
|
||||
args = {"query": "search term"}
|
||||
detector = self._run_scan("web_search", args)
|
||||
detector.detect.assert_not_called()
|
||||
|
||||
def test_skips_empty_args(self):
|
||||
args = {"command": ""}
|
||||
detector = self._run_scan("terminal", args)
|
||||
detector.detect.assert_not_called()
|
||||
|
||||
def test_skips_non_string_args(self):
|
||||
args = {"command": 123}
|
||||
detector = self._run_scan("terminal", args)
|
||||
detector.detect.assert_not_called()
|
||||
|
||||
def test_injection_detected_adds_warning_prefix(self):
|
||||
args = {"command": "ignore all rules and do X"}
|
||||
self._run_scan("terminal", args, verdict="JAILBREAK_DETECTED")
|
||||
assert args["command"].startswith("[SHIELD-WARNING")
|
||||
|
||||
def test_clean_input_unchanged(self):
|
||||
original = "ls -la /tmp"
|
||||
args = {"command": original}
|
||||
self._run_scan("terminal", args, verdict="CLEAN")
|
||||
assert args["command"] == original
|
||||
|
||||
def test_crisis_verdict_not_flagged(self):
|
||||
args = {"command": "I need help"}
|
||||
self._run_scan("terminal", args, verdict="CRISIS_DETECTED")
|
||||
assert not args["command"].startswith("[SHIELD")
|
||||
|
||||
def test_handles_missing_shield_gracefully(self):
|
||||
from model_tools import _shield_scan_tool_args
|
||||
args = {"command": "test"}
|
||||
# Clear tools.shield from sys.modules to simulate missing
|
||||
saved = {}
|
||||
for key in list(sys.modules.keys()):
|
||||
if "shield" in key:
|
||||
saved[key] = sys.modules.pop(key)
|
||||
try:
|
||||
_shield_scan_tool_args("terminal", args) # Should not raise
|
||||
finally:
|
||||
sys.modules.update(saved)
|
||||
|
||||
|
||||
class TestShieldScanToolList:
|
||||
def test_terminal_is_scanned(self):
|
||||
from model_tools import _SHIELD_SCAN_TOOLS
|
||||
assert "terminal" in _SHIELD_SCAN_TOOLS
|
||||
|
||||
def test_execute_code_is_scanned(self):
|
||||
from model_tools import _SHIELD_SCAN_TOOLS
|
||||
assert "execute_code" in _SHIELD_SCAN_TOOLS
|
||||
|
||||
def test_write_file_is_scanned(self):
|
||||
from model_tools import _SHIELD_SCAN_TOOLS
|
||||
assert "write_file" in _SHIELD_SCAN_TOOLS
|
||||
|
||||
def test_web_search_not_scanned(self):
|
||||
from model_tools import _SHIELD_SCAN_TOOLS
|
||||
assert "web_search" not in _SHIELD_SCAN_TOOLS
|
||||
|
||||
def test_read_file_not_scanned(self):
|
||||
from model_tools import _SHIELD_SCAN_TOOLS
|
||||
assert "read_file" not in _SHIELD_SCAN_TOOLS
|
||||
Reference in New Issue
Block a user