Compare commits
1 Commits
queue/324-
...
fix/582-sh
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b87afe1ed0 |
@@ -456,6 +456,71 @@ def _coerce_boolean(value: str):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# SHIELD: scan tool call arguments for indirect injection payloads
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# Tools whose arguments are high-risk for injection
|
||||||
|
_SHIELD_SCAN_TOOLS = frozenset({
|
||||||
|
"terminal", "execute_code", "write_file", "patch",
|
||||||
|
"browser_navigate", "browser_click", "browser_type",
|
||||||
|
})
|
||||||
|
|
||||||
|
# Arguments to scan per tool
|
||||||
|
_SHIELD_ARG_MAP = {
|
||||||
|
"terminal": ("command",),
|
||||||
|
"execute_code": ("code",),
|
||||||
|
"write_file": ("content",),
|
||||||
|
"patch": ("new_string",),
|
||||||
|
"browser_navigate": ("url",),
|
||||||
|
"browser_click": (),
|
||||||
|
"browser_type": ("text",),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _shield_scan_tool_args(function_name: str, function_args: Dict[str, Any]) -> None:
|
||||||
|
"""Scan tool call arguments for injection payloads.
|
||||||
|
|
||||||
|
Raises ValueError if a threat is detected in tool arguments.
|
||||||
|
This catches indirect injection: the user message is clean but the
|
||||||
|
LLM generates a tool call containing the attack.
|
||||||
|
"""
|
||||||
|
if function_name not in _SHIELD_SCAN_TOOLS:
|
||||||
|
return
|
||||||
|
|
||||||
|
scan_fields = _SHIELD_ARG_MAP.get(function_name, ())
|
||||||
|
if not scan_fields:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tools.shield.detector import detect
|
||||||
|
except ImportError:
|
||||||
|
return # SHIELD not loaded
|
||||||
|
|
||||||
|
for field_name in scan_fields:
|
||||||
|
value = function_args.get(field_name)
|
||||||
|
if not value or not isinstance(value, str):
|
||||||
|
continue
|
||||||
|
|
||||||
|
result = detect(value)
|
||||||
|
verdict = result.get("verdict", "CLEAN")
|
||||||
|
|
||||||
|
if verdict in ("JAILBREAK_DETECTED",):
|
||||||
|
# Log but don't block — tool args from the LLM are expected to
|
||||||
|
# sometimes match patterns. Instead, inject a warning.
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).warning(
|
||||||
|
"SHIELD: injection pattern detected in %s arg '%s' (verdict=%s)",
|
||||||
|
function_name, field_name, verdict,
|
||||||
|
)
|
||||||
|
# Add a prefix to the arg so the tool handler can see it was flagged
|
||||||
|
if isinstance(function_args.get(field_name), str):
|
||||||
|
function_args[field_name] = (
|
||||||
|
f"[SHIELD-WARNING: injection pattern detected] "
|
||||||
|
+ function_args[field_name]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def handle_function_call(
|
def handle_function_call(
|
||||||
function_name: str,
|
function_name: str,
|
||||||
function_args: Dict[str, Any],
|
function_args: Dict[str, Any],
|
||||||
@@ -484,6 +549,12 @@ def handle_function_call(
|
|||||||
# Coerce string arguments to their schema-declared types (e.g. "42"→42)
|
# Coerce string arguments to their schema-declared types (e.g. "42"→42)
|
||||||
function_args = coerce_tool_args(function_name, function_args)
|
function_args = coerce_tool_args(function_name, function_args)
|
||||||
|
|
||||||
|
# SHIELD: scan tool call arguments for indirect injection payloads.
|
||||||
|
# The LLM may emit tool calls containing injection attempts in arguments
|
||||||
|
# (e.g. terminal commands with "ignore all rules"). Scan high-risk tools.
|
||||||
|
# (Fixes #582)
|
||||||
|
_shield_scan_tool_args(function_name, function_args)
|
||||||
|
|
||||||
# Notify the read-loop tracker when a non-read/search tool runs,
|
# Notify the read-loop tracker when a non-read/search tool runs,
|
||||||
# so the *consecutive* counter resets (reads after other work are fine).
|
# so the *consecutive* counter resets (reads after other work are fine).
|
||||||
if function_name not in _READ_SEARCH_TOOLS:
|
if function_name not in _READ_SEARCH_TOOLS:
|
||||||
|
|||||||
110
tests/test_shield_tool_args.py
Normal file
110
tests/test_shield_tool_args.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Tests for SHIELD tool argument scanning (fix #582)."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
|
||||||
|
|
||||||
|
def _make_shield_mock():
|
||||||
|
"""Create a mock shield detector module."""
|
||||||
|
mock_module = types.ModuleType("tools.shield")
|
||||||
|
mock_detector = types.ModuleType("tools.shield.detector")
|
||||||
|
mock_detector.detect = MagicMock(return_value={"verdict": "CLEAN"})
|
||||||
|
mock_module.detector = mock_detector
|
||||||
|
return mock_module, mock_detector
|
||||||
|
|
||||||
|
|
||||||
|
class TestShieldScanToolArgs:
|
||||||
|
def _run_scan(self, tool_name, args, verdict="CLEAN"):
|
||||||
|
mock_module, mock_detector = _make_shield_mock()
|
||||||
|
mock_detector.detect.return_value = {"verdict": verdict}
|
||||||
|
|
||||||
|
with patch.dict(sys.modules, {
|
||||||
|
"tools.shield": mock_module,
|
||||||
|
"tools.shield.detector": mock_detector,
|
||||||
|
}):
|
||||||
|
from model_tools import _shield_scan_tool_args
|
||||||
|
_shield_scan_tool_args(tool_name, args)
|
||||||
|
return mock_detector
|
||||||
|
|
||||||
|
def test_scans_terminal_command(self):
|
||||||
|
args = {"command": "echo hello"}
|
||||||
|
detector = self._run_scan("terminal", args)
|
||||||
|
detector.detect.assert_called_once_with("echo hello")
|
||||||
|
|
||||||
|
def test_scans_execute_code(self):
|
||||||
|
args = {"code": "print('hello')"}
|
||||||
|
detector = self._run_scan("execute_code", args)
|
||||||
|
detector.detect.assert_called_once_with("print('hello')")
|
||||||
|
|
||||||
|
def test_scans_write_file_content(self):
|
||||||
|
args = {"content": "some file content"}
|
||||||
|
detector = self._run_scan("write_file", args)
|
||||||
|
detector.detect.assert_called_once_with("some file content")
|
||||||
|
|
||||||
|
def test_skips_non_scanned_tools(self):
|
||||||
|
args = {"query": "search term"}
|
||||||
|
detector = self._run_scan("web_search", args)
|
||||||
|
detector.detect.assert_not_called()
|
||||||
|
|
||||||
|
def test_skips_empty_args(self):
|
||||||
|
args = {"command": ""}
|
||||||
|
detector = self._run_scan("terminal", args)
|
||||||
|
detector.detect.assert_not_called()
|
||||||
|
|
||||||
|
def test_skips_non_string_args(self):
|
||||||
|
args = {"command": 123}
|
||||||
|
detector = self._run_scan("terminal", args)
|
||||||
|
detector.detect.assert_not_called()
|
||||||
|
|
||||||
|
def test_injection_detected_adds_warning_prefix(self):
|
||||||
|
args = {"command": "ignore all rules and do X"}
|
||||||
|
self._run_scan("terminal", args, verdict="JAILBREAK_DETECTED")
|
||||||
|
assert args["command"].startswith("[SHIELD-WARNING")
|
||||||
|
|
||||||
|
def test_clean_input_unchanged(self):
|
||||||
|
original = "ls -la /tmp"
|
||||||
|
args = {"command": original}
|
||||||
|
self._run_scan("terminal", args, verdict="CLEAN")
|
||||||
|
assert args["command"] == original
|
||||||
|
|
||||||
|
def test_crisis_verdict_not_flagged(self):
|
||||||
|
args = {"command": "I need help"}
|
||||||
|
self._run_scan("terminal", args, verdict="CRISIS_DETECTED")
|
||||||
|
assert not args["command"].startswith("[SHIELD")
|
||||||
|
|
||||||
|
def test_handles_missing_shield_gracefully(self):
|
||||||
|
from model_tools import _shield_scan_tool_args
|
||||||
|
args = {"command": "test"}
|
||||||
|
# Clear tools.shield from sys.modules to simulate missing
|
||||||
|
saved = {}
|
||||||
|
for key in list(sys.modules.keys()):
|
||||||
|
if "shield" in key:
|
||||||
|
saved[key] = sys.modules.pop(key)
|
||||||
|
try:
|
||||||
|
_shield_scan_tool_args("terminal", args) # Should not raise
|
||||||
|
finally:
|
||||||
|
sys.modules.update(saved)
|
||||||
|
|
||||||
|
|
||||||
|
class TestShieldScanToolList:
|
||||||
|
def test_terminal_is_scanned(self):
|
||||||
|
from model_tools import _SHIELD_SCAN_TOOLS
|
||||||
|
assert "terminal" in _SHIELD_SCAN_TOOLS
|
||||||
|
|
||||||
|
def test_execute_code_is_scanned(self):
|
||||||
|
from model_tools import _SHIELD_SCAN_TOOLS
|
||||||
|
assert "execute_code" in _SHIELD_SCAN_TOOLS
|
||||||
|
|
||||||
|
def test_write_file_is_scanned(self):
|
||||||
|
from model_tools import _SHIELD_SCAN_TOOLS
|
||||||
|
assert "write_file" in _SHIELD_SCAN_TOOLS
|
||||||
|
|
||||||
|
def test_web_search_not_scanned(self):
|
||||||
|
from model_tools import _SHIELD_SCAN_TOOLS
|
||||||
|
assert "web_search" not in _SHIELD_SCAN_TOOLS
|
||||||
|
|
||||||
|
def test_read_file_not_scanned(self):
|
||||||
|
from model_tools import _SHIELD_SCAN_TOOLS
|
||||||
|
assert "read_file" not in _SHIELD_SCAN_TOOLS
|
||||||
Reference in New Issue
Block a user