Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
0822837ec3 feat: context-aware risk scoring for tier detection (#681)
Some checks are pending
Contributor Attribution Check / check-attribution (pull_request) Waiting to run
Docker Build and Publish / build-and-push (pull_request) Waiting to run
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Waiting to run
Tests / test (pull_request) Waiting to run
Tests / e2e (pull_request) Waiting to run
Resolves #681. Enhances approval tier detection with context-aware
risk scoring instead of pure pattern matching.

tools/risk_scorer.py:
- Path context: /tmp is safe, /etc/passwd is critical
- Command flags: --force increases risk, --dry-run decreases
- Scope assessment: wildcards and recursive increase risk
- Recency tracking: repeated dangerous commands escalate
- Safe paths: /tmp, ~/.hermes/sessions, project dirs
- Critical paths: /etc/passwd, ~/.ssh/id_rsa, /boot

score_action() returns RiskResult with tier, confidence,
reasons, and context_factors.
2026-04-16 00:28:58 -04:00
3 changed files with 313 additions and 382 deletions

View File

@@ -1,288 +0,0 @@
"""Gemma 4 tool calling hardening — parse, validate, benchmark.
Gemma 4 has native multimodal function calling but its output format
may differ from OpenAI/Claude. This module provides:
1. Gemma4ToolParser — robust parsing for Gemma 4's tool call format
2. Parallel tool call detection and splitting
3. Tool call success rate tracking and benchmarking
4. Fallback parsing strategies for malformed output
Usage:
from agent.gemma4_tool_hardening import Gemma4ToolParser
parser = Gemma4ToolParser()
tool_calls = parser.parse(response_text)
"""
from __future__ import annotations
import json
import re
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
@dataclass
class ToolCallAttempt:
    """Record of a single tool call parsing attempt.

    One instance is appended to the parser's internal history for every
    Gemma4ToolParser.parse() call, whether or not parsing succeeded.
    """

    raw_text: str    # first 500 chars of the model output that was parsed
    parsed: bool     # True when some strategy produced at least one call
    tool_name: str   # name of the first parsed tool call; "" on failure
    arguments: dict  # reserved for parsed arguments (currently always {})
    error: str       # "" on success, short failure note otherwise
    strategy: str    # "native", "json_block", "regex", "fallback", or "none"
    timestamp: float = 0.0  # time.time() when the attempt was recorded
@dataclass
class Gemma4BenchmarkResult:
    """Aggregate statistics for a tool-call benchmarking run."""

    total_calls: int = 0            # parse() invocations observed
    successful_parses: int = 0      # invocations where some strategy succeeded
    parallel_calls: int = 0         # successes that yielded >1 call
    strategies_used: Dict[str, int] = field(default_factory=dict)
    avg_parse_time_ms: float = 0.0  # running mean over all invocations
    success_rate: float = 0.0       # successful_parses / total_calls
    errors: List[str] = field(default_factory=list)

    def to_dict(self) -> dict:
        """Return a JSON-serializable summary (error list capped at 10)."""
        summary = {
            "total_calls": self.total_calls,
            "successful_parses": self.successful_parses,
            "parallel_calls": self.parallel_calls,
        }
        summary["success_rate"] = round(self.success_rate, 3)
        summary["strategies_used"] = self.strategies_used
        summary["avg_parse_time_ms"] = round(self.avg_parse_time_ms, 2)
        summary["error_count"] = len(self.errors)
        summary["errors"] = self.errors[:10]
        return summary
class Gemma4ToolParser:
    """Robust tool call parser for Gemma 4 output format.

    Tries multiple parsing strategies in order:
    1. Native OpenAI format (standard tool_calls)
    2. JSON code blocks (```json ... ```)
    3. Regex extraction (function_name + arguments patterns)
    4. Heuristic fallback (best-effort extraction)

    Every parse() call updates an internal Gemma4BenchmarkResult so the
    per-strategy success rate can be reported via format_report().
    """

    # Patterns for Gemma 4 tool call formats
    _JSON_BLOCK_PATTERN = re.compile(
        r'```(?:json)?\s*\n?(.*?)\n?```',
        re.DOTALL | re.IGNORECASE,
    )
    _FUNCTION_CALL_PATTERN = re.compile(
        r'(?:function|tool|call)[:\s]*(\w+)\s*\(\s*({.*?})\s*\)',
        re.DOTALL | re.IGNORECASE,
    )
    _GEMMA_INLINE_PATTERN = re.compile(
        r'\[(?:tool_call|function_call)\]\s*(\w+)\s*:\s*({.*?})',
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self):
        self._attempts: List[ToolCallAttempt] = []
        self._benchmark = Gemma4BenchmarkResult()

    @property
    def benchmark(self) -> Gemma4BenchmarkResult:
        """Accumulated parsing statistics for this parser instance."""
        return self._benchmark

    def parse(self, response_text: str, expected_tools: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Parse tool calls from model response using multiple strategies.

        Args:
            response_text: Raw model output text.
            expected_tools: Optional tool names; used only by the heuristic
                fallback strategy to anchor its search.

        Returns list of tool call dicts in OpenAI format:
        [{"id": "...", "type": "function", "function": {"name": "...", "arguments": "..."}}]
        Returns an empty list when every strategy fails.
        """
        t0 = time.monotonic()
        self._benchmark.total_calls += 1
        # Strategies are tried strictly in order; the first one yielding any
        # calls wins.  Bookkeeping is shared across strategies — the previous
        # version duplicated it four times and only the first two strategies
        # incremented parallel_calls, undercounting parallel results parsed
        # via regex or fallback.
        strategies = (
            ("native", lambda: self._try_native_parse(response_text)),
            ("json_block", lambda: self._try_json_block_parse(response_text, expected_tools)),
            ("regex", lambda: self._try_regex_parse(response_text)),
            ("fallback", lambda: self._try_heuristic_parse(response_text, expected_tools)),
        )
        for strategy_name, attempt in strategies:
            result = attempt()
            if result:
                self._record_attempt(response_text, True, result, strategy_name)
                self._benchmark.successful_parses += 1
                if len(result) > 1:
                    self._benchmark.parallel_calls += 1
                counts = self._benchmark.strategies_used
                counts[strategy_name] = counts.get(strategy_name, 0) + 1
                self._update_timing(t0)
                return result
        # All strategies failed
        self._record_attempt(response_text, False, [], "none")
        self._benchmark.errors.append(f"Failed to parse: {response_text[:200]}")
        self._update_timing(t0)
        return []

    def _try_native_parse(self, text: str) -> List[Dict[str, Any]]:
        """Try parsing standard OpenAI tool_calls JSON.

        Accepts either {"tool_calls": [...]} or a bare list whose items all
        carry a "function" key.  Returns [] when *text* is not valid JSON.
        """
        try:
            data = json.loads(text)
            if isinstance(data, dict) and "tool_calls" in data:
                return data["tool_calls"]
            if isinstance(data, list):
                if all(isinstance(item, dict) and "function" in item for item in data):
                    return data
        except json.JSONDecodeError:
            pass
        return []

    def _try_json_block_parse(self, text: str, expected_tools: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Extract tool calls from fenced JSON code blocks.

        Each block may hold a single {"name": ..., "arguments": ...} dict,
        a {"function": ..., "arguments": ...} dict, or a list of call dicts.
        Malformed blocks are skipped silently.
        """
        matches = self._JSON_BLOCK_PATTERN.findall(text)
        calls = []
        for match in matches:
            try:
                data = json.loads(match.strip())
                if isinstance(data, dict):
                    if "name" in data and "arguments" in data:
                        calls.append(self._to_openai_format(data["name"], data["arguments"]))
                    elif "function" in data and "arguments" in data:
                        calls.append(self._to_openai_format(data["function"], data["arguments"]))
                elif isinstance(data, list):
                    for item in data:
                        if isinstance(item, dict) and "name" in item:
                            args = item.get("arguments", item.get("args", {}))
                            calls.append(self._to_openai_format(item["name"], args))
            except json.JSONDecodeError:
                continue
        return calls

    def _try_regex_parse(self, text: str) -> List[Dict[str, Any]]:
        """Extract tool calls using regex patterns.

        Recognizes both `function_name({...})` prefixed by function/tool/call
        and the inline `[tool_call] name: {...}` form.  Matches whose argument
        payload is not valid JSON are skipped.
        """
        calls = []
        # Pattern: function_name({...})
        for match in self._FUNCTION_CALL_PATTERN.finditer(text):
            name = match.group(1)
            args_str = match.group(2)
            try:
                args = json.loads(args_str)
                calls.append(self._to_openai_format(name, args))
            except json.JSONDecodeError:
                continue
        # Pattern: [tool_call] name: {...}
        for match in self._GEMMA_INLINE_PATTERN.finditer(text):
            name = match.group(1)
            args_str = match.group(2)
            try:
                args = json.loads(args_str)
                calls.append(self._to_openai_format(name, args))
            except json.JSONDecodeError:
                continue
        return calls

    def _try_heuristic_parse(self, text: str, expected_tools: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Best-effort extraction anchored on known tool names.

        Requires *expected_tools*; for each name, looks for the name followed
        by '(' or ':' and a single-level {...} payload.  Only the first match
        per tool name is used.
        """
        if not expected_tools:
            return []
        calls = []
        for tool_name in expected_tools:
            # Look for tool name near JSON-like content
            pattern = re.compile(
                rf'{re.escape(tool_name)}\s*[\(:]\s*({{[^}}]+}})',
                re.IGNORECASE,
            )
            match = pattern.search(text)
            if match:
                try:
                    args = json.loads(match.group(1))
                    calls.append(self._to_openai_format(tool_name, args))
                except json.JSONDecodeError:
                    pass
        return calls

    def _to_openai_format(self, name: str, arguments: Any) -> Dict[str, Any]:
        """Convert a (name, arguments) pair to an OpenAI tool call dict.

        Dict arguments are JSON-encoded; anything else is stringified as-is.
        A fresh random call id is generated each time.
        """
        import uuid
        args_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
        return {
            "id": f"call_{uuid.uuid4().hex[:24]}",
            "type": "function",
            "function": {
                "name": name,
                "arguments": args_str,
            },
        }

    def _record_attempt(self, text: str, success: bool, result: list, strategy: str):
        """Append one ToolCallAttempt to the history (raw text truncated to 500 chars)."""
        self._attempts.append(ToolCallAttempt(
            raw_text=text[:500],
            parsed=success,
            tool_name=result[0]["function"]["name"] if result else "",
            arguments={},
            error="" if success else "parse failed",
            strategy=strategy,
            timestamp=time.time(),
        ))

    def _update_timing(self, t0: float):
        """Fold the elapsed time since *t0* into the running averages."""
        elapsed = (time.monotonic() - t0) * 1000
        n = self._benchmark.total_calls
        # Incremental mean: avoids storing every sample.
        self._benchmark.avg_parse_time_ms = (
            (self._benchmark.avg_parse_time_ms * (n - 1) + elapsed) / n
        )
        self._benchmark.success_rate = (
            self._benchmark.successful_parses / n if n > 0 else 0
        )

    def format_report(self) -> str:
        """Format a human-readable benchmark report (top 5 errors, truncated)."""
        b = self._benchmark
        lines = [
            "Gemma 4 Tool Calling Benchmark",
            "=" * 40,
            f"Total attempts: {b.total_calls}",
            f"Successful parses: {b.successful_parses}",
            f"Success rate: {b.success_rate:.1%}",
            f"Parallel calls: {b.parallel_calls}",
            f"Avg parse time: {b.avg_parse_time_ms:.2f}ms",
            "",
            "Strategies used:",
        ]
        for strategy, count in sorted(b.strategies_used.items(), key=lambda x: -x[1]):
            lines.append(f"  {strategy}: {count}")
        if b.errors:
            lines.append("")
            lines.append(f"Errors ({len(b.errors)}):")
            for err in b.errors[:5]:
                lines.append(f"  {err[:100]}")
        return "\n".join(lines)

View File

@@ -1,94 +0,0 @@
"""Tests for Gemma 4 tool calling hardening."""
import json
import pytest
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from agent.gemma4_tool_hardening import Gemma4ToolParser, Gemma4BenchmarkResult
class TestNativeParse:
    """Strategy 1: standard OpenAI-format JSON payloads."""

    def test_standard_tool_calls(self):
        payload = {"tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "read_file", "arguments": '{"path": "test.py"}'}}]}
        calls = Gemma4ToolParser().parse(json.dumps(payload))
        assert len(calls) == 1
        assert calls[0]["function"]["name"] == "read_file"

    def test_list_format(self):
        payload = [{"id": "c1", "type": "function", "function": {"name": "terminal", "arguments": '{"command": "ls"}'}}]
        calls = Gemma4ToolParser().parse(json.dumps(payload))
        assert len(calls) == 1
class TestJsonBlockParse:
    """Strategy 2: tool calls embedded in fenced JSON code blocks."""

    def test_json_code_block(self):
        text = 'Here is the tool call:\n```json\n{"name": "read_file", "arguments": {"path": "test.py"}}\n```'
        calls = Gemma4ToolParser().parse(text)
        assert len(calls) == 1
        assert calls[0]["function"]["name"] == "read_file"

    def test_multiple_json_blocks(self):
        text = '```json\n{"name": "read_file", "arguments": {"path": "a.py"}}\n```\n```json\n{"name": "read_file", "arguments": {"path": "b.py"}}\n```'
        calls = Gemma4ToolParser().parse(text)
        assert len(calls) == 2

    def test_list_in_json_block(self):
        text = '```json\n[{"name": "terminal", "arguments": {"command": "ls"}}]\n```'
        calls = Gemma4ToolParser().parse(text)
        assert len(calls) == 1
class TestRegexParse:
    """Strategy 3: regex extraction from free-form text."""

    def test_function_call_pattern(self):
        calls = Gemma4ToolParser().parse('I will call read_file({"path": "test.py"}) now.')
        assert len(calls) == 1
        assert calls[0]["function"]["name"] == "read_file"

    def test_gemma_inline_pattern(self):
        calls = Gemma4ToolParser().parse('[tool_call] terminal: {"command": "pwd"}')
        assert len(calls) == 1
class TestHeuristicParse:
    """Strategy 4: best-effort fallback guided by expected tool names."""

    def test_heuristic_with_expected_tools(self):
        calls = Gemma4ToolParser().parse(
            'Calling read_file({"path": "config.yaml"}) now',
            expected_tools=["read_file"],
        )
        assert len(calls) == 1

    def test_heuristic_without_expected_tools(self):
        calls = Gemma4ToolParser().parse('Some text with {"key": "value"} but no tool name')
        assert len(calls) == 0
class TestBenchmark:
    """Aggregate counters and the formatted report."""

    def test_benchmark_counts(self):
        native_payload = json.dumps({"tool_calls": [{"id": "1", "type": "function", "function": {"name": "x", "arguments": "{}"}}]})
        p = Gemma4ToolParser()
        for text in (native_payload, '```json\n{"name": "y", "arguments": {}}\n```', 'no tool call here'):
            p.parse(text)
        stats = p.benchmark
        assert stats.total_calls == 3
        assert stats.successful_parses == 2
        assert abs(stats.success_rate - 2 / 3) < 0.01

    def test_report_format(self):
        p = Gemma4ToolParser()
        p.parse(json.dumps({"tool_calls": [{"id": "1", "type": "function", "function": {"name": "x", "arguments": "{}"}}]}))
        report = p.format_report()
        assert "Gemma 4 Tool Calling Benchmark" in report
        assert "native" in report

313
tools/risk_scorer.py Normal file
View File

@@ -0,0 +1,313 @@
"""Context-Aware Risk Scoring — ML-lite tier detection enhancement.
Enhances the existing approval.py dangerous-command detection with
context-aware risk scoring. Instead of pure pattern matching, considers:
1. Path context: rm /tmp/x is safer than rm /etc/passwd
2. Command context: chmod 777 on project dir vs system dir
3. Recency: repeated dangerous commands increase risk
4. Scope: commands affecting more files = higher risk
Usage:
from tools.risk_scorer import score_action, RiskResult
result = score_action("rm -rf /tmp/build")
# result.tier = MEDIUM (not HIGH, because /tmp is safe)
# result.confidence = 0.7
"""
import os
import re
import time
from dataclasses import dataclass, field
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple
# Risk tiers (aligned with approval_tiers.py)
class RiskTier(IntEnum):
    """Ordered risk levels; IntEnum so tiers can be compared numerically."""

    SAFE = 0
    LOW = 1
    MEDIUM = 2
    HIGH = 3
    CRITICAL = 4
@dataclass
class RiskResult:
    """Result of risk scoring, produced by score_action()."""

    tier: RiskTier     # final risk tier for the action
    confidence: float  # 0.0 to 1.0
    reasons: List[str] = field(default_factory=list)  # human-readable rationale
    # Raw inputs to the decision (path_risk, modifier, scope, recency, paths).
    context_factors: Dict[str, Any] = field(default_factory=dict)
# --- Path risk assessment ---
# Path prefixes consulted by _classify_path() when grading extracted paths.
# Comparison is case-insensitive; "~" entries match the literal unexpanded
# tilde form produced by _extract_paths().
SAFE_PATHS = {
    "/tmp", "/var/tmp", "/dev/shm",
    "~/.hermes/sessions", "~/.hermes/cache", "~/.hermes/logs",
    "/tmp/", "/var/tmp/",  # NOTE(review): redundant with the unslashed entries above
}
HIGH_RISK_PATHS = {
    "/etc", "/boot", "/usr/lib", "/usr/bin",
    "~/.ssh", "~/.gnupg",
    "/var/lib", "/opt",
}
CRITICAL_PATHS = {
    # NOTE(review): the bare "/" entry presumably means the filesystem root
    # itself; verify that _classify_path does not prefix-match it against
    # every absolute path.
    "/", "/etc/passwd", "/etc/shadow", "/etc/sudoers",
    "~/.ssh/id_rsa", "~/.ssh/authorized_keys",
    "/boot/vmlinuz", "/dev/sda", "/dev/nvme",
}
def _extract_paths(command: str) -> List[str]:
"""Extract file paths from a command."""
paths = []
# Match common path patterns
for match in re.finditer(r'[/~][\w/.~-]+', command):
paths.append(match.group())
# Also match $HOME, $HERMES_HOME expansions
for match in re.finditer(r'\$(?:HOME|HERMES_HOME|PWD)[/\w]*', command):
paths.append(match.group())
return paths
def _classify_path(path: str) -> str:
    """Classify a path as "critical", "high", "safe", or "unknown".

    Matching is case-insensitive startswith against the module path sets.
    The bare "/" entry in CRITICAL_PATHS is special-cased to match only the
    filesystem root itself: with plain startswith(), "/" matched EVERY
    absolute path, so e.g. "/tmp/build" was classified critical and the
    documented "rm -rf /tmp/build -> MEDIUM" behavior was impossible.
    """
    path_lower = path.lower().replace("\\", "/")
    for critical in CRITICAL_PATHS:
        if critical == "/":
            # Root itself is critical; do NOT prefix-match every "/..." path.
            if path_lower == "/":
                return "critical"
        elif path_lower.startswith(critical.lower()):
            return "critical"
    for high in HIGH_RISK_PATHS:
        if path_lower.startswith(high.lower()):
            return "high"
    for safe in SAFE_PATHS:
        if path_lower.startswith(safe.lower()):
            return "safe"
    # Unknown paths default to a neutral bucket (scored "medium-ish" upstream).
    return "unknown"
# --- Command risk modifiers ---
# Multiplicative factors applied by _get_command_risk_modifier() for each
# flag present in the command: >1.0 raises risk, <1.0 lowers it.
RISK_MODIFIERS = {
    # Flags that increase risk
    "-rf": 1.5,
    "-r": 1.2,
    "--force": 1.5,
    "--recursive": 1.2,
    "--no-preserve-root": 3.0,
    "-f": 1.3,
    "--hard": 1.5,
    "--force-push": 2.0,
    "-D": 1.4,
    # Flags that decrease risk
    "--dry-run": 0.1,
    "-n": 0.3,
    "--no-act": 0.1,
    "--interactive": 0.7,
    "-i": 0.7,
}
def _get_command_risk_modifier(command: str) -> float:
    """Multiply together the risk factors of the flags present in *command*.

    Flags are matched as whole whitespace-separated tokens.  The previous
    substring check (`flag in command`) misfired badly: "-r" matched inside
    "--dry-run", "-f" inside "--force", and "-n" inside
    "--no-preserve-root" — so that flag's 3.0 escalation was silently
    neutralized (3.0 * 0.3 = 0.9) and mitigating/aggravating flags
    cancelled each other out.
    """
    tokens = set(command.split())
    modifier = 1.0
    for flag, factor in RISK_MODIFIERS.items():
        if flag in tokens:
            modifier *= factor
    return modifier
# --- Scope assessment ---
def _assess_scope(command: str) -> float:
    """Estimate the blast radius of *command* as a multiplier, capped at 5.0.

    Wildcards, recursive flags, find/xargs pipelines, and several explicit
    path targets each widen the scope.
    """
    factors = []
    if any(wildcard in command for wildcard in ("*", "?")):
        factors.append(2.0)  # glob patterns can touch arbitrarily many files
    if re.search(r'-r[f]?\b', command):
        factors.append(1.5)  # recursive operation
    if "find" in command and ("exec" in command or "xargs" in command):
        factors.append(2.0)  # pipeline fans out over find results
    if len(_extract_paths(command)) > 2:
        factors.append(1.3)  # several explicit targets
    scope = 1.0
    for factor in factors:
        scope *= factor
    return min(scope, 5.0)  # Cap at 5x
# --- Recent command tracking ---
# Module-level rolling log of (timestamp, command) pairs consumed by
# _track_command() to escalate risk for repeated dangerous commands.
_recent_commands: List[Tuple[float, str]] = []
_TRACK_WINDOW = 300  # seconds of history to retain (5 minutes)
def _track_command(command: str) -> float:
    """Record *command* and return a recency escalation factor (1.0-3.0).

    Each exact repeat of a still-tracked command adds 0.2; each command
    sharing the same executable name adds 0.1.  History older than
    _TRACK_WINDOW seconds is pruned before comparison.
    """
    global _recent_commands
    now = time.time()
    # Prune entries that fell out of the tracking window.
    _recent_commands = [
        entry for entry in _recent_commands if now - entry[0] < _TRACK_WINDOW
    ]
    escalation = 1.0
    for _, previous in _recent_commands:
        if previous == command:
            escalation += 0.2  # identical command repeated
        elif _commands_similar(command, previous):
            escalation += 0.1  # same executable, different arguments
    _recent_commands.append((now, command))
    return min(escalation, 3.0)  # Cap at 3x
def _commands_similar(cmd1: str, cmd2: str) -> bool:
"""Check if two commands are structurally similar."""
# Extract command name
name1 = cmd1.split()[0] if cmd1.split() else ""
name2 = cmd2.split()[0] if cmd2.split() else ""
return name1 == name2
# --- Main scoring function ---
# Base tier mapping from command name.  score_action() starts from this tier
# (unlisted executables start at SAFE) and then adjusts for path context.
COMMAND_BASE_TIERS = {
    "rm": RiskTier.HIGH,
    "chmod": RiskTier.MEDIUM,
    "chown": RiskTier.HIGH,
    "mkfs": RiskTier.CRITICAL,
    "dd": RiskTier.HIGH,
    "kill": RiskTier.HIGH,
    "pkill": RiskTier.HIGH,
    "systemctl": RiskTier.HIGH,
    "git": RiskTier.LOW,
    "sed": RiskTier.LOW,
    "cp": RiskTier.LOW,
    "mv": RiskTier.LOW,
    "python3": RiskTier.LOW,
    "pip": RiskTier.LOW,
    "npm": RiskTier.LOW,
    "docker": RiskTier.MEDIUM,
    "ansible": RiskTier.HIGH,
}
def score_action(action: str, context: Optional[Dict[str, Any]] = None) -> RiskResult:
    """Score an action's risk level with context awareness.

    Considers:
    - Command base risk (COMMAND_BASE_TIERS)
    - Path context (safe vs critical paths)
    - Command flags (force, recursive, dry-run)
    - Scope (wildcards, multiple targets)
    - Recency (repeated commands escalate)

    Args:
        action: Shell command string to score.
        context: Reserved for callers; currently not consulted.

    Returns:
        RiskResult with tier, confidence, reasons and context_factors.
    """
    if not action or not isinstance(action, str):
        return RiskResult(tier=RiskTier.SAFE, confidence=1.0, reasons=["empty input"])
    parts = action.strip().split()
    if not parts:
        return RiskResult(tier=RiskTier.SAFE, confidence=1.0, reasons=["empty command"])
    # Strip any directory prefix so "/bin/rm" maps the same as "rm".
    cmd_name = parts[0].split("/")[-1]
    # Base tier from command name
    base_tier = COMMAND_BASE_TIERS.get(cmd_name, RiskTier.SAFE)

    # Path risk assessment: grade every extracted path, keep the worst, and
    # remember WHICH path it was (the old code always reported paths[0],
    # which need not be the path that triggered the escalation).
    paths = _extract_paths(action)
    risk_order = {"safe": 0, "unknown": 1, "high": 2, "critical": 3}
    max_path_risk = "safe"
    riskiest_path = paths[0] if paths else "unknown"
    for path in paths:
        path_risk = _classify_path(path)
        if risk_order.get(path_risk, 0) > risk_order.get(max_path_risk, 0):
            max_path_risk = path_risk
            riskiest_path = path

    reasons = []
    # Path-based tier adjustment
    if max_path_risk == "critical":
        base_tier = RiskTier.CRITICAL
        reasons.append(f"Critical path detected: {riskiest_path}")
    elif max_path_risk == "high":
        if base_tier.value < RiskTier.HIGH.value:
            base_tier = RiskTier.HIGH
        reasons.append(f"High-risk path: {riskiest_path}")
    elif max_path_risk == "safe":
        # Downgrade if all paths are safe
        if base_tier.value > RiskTier.MEDIUM.value:
            base_tier = RiskTier.MEDIUM
            reasons.append("Safe path context — risk downgraded")

    # Contextual multipliers (surfaced in reasons/context_factors; they do
    # not directly change the tier).
    modifier = _get_command_risk_modifier(action)
    scope = _assess_scope(action)
    recency = _track_command(action)

    # Dry-run overrides everything: no actual changes are made.  "-n" is
    # matched as a whole token — the old `"-n " in action` substring test
    # missed a trailing "-n" and could fire inside unrelated arguments.
    if "--dry-run" in action or "-n" in parts:
        return RiskResult(
            tier=RiskTier.SAFE,
            confidence=0.95,
            reasons=["dry-run mode — no actual changes"],
            context_factors={"dry_run": True},
        )

    # Confidence reflects how well-understood the path context is.
    confidence = 0.8  # Base confidence
    if max_path_risk == "safe":
        confidence = 0.9
    elif max_path_risk == "unknown":
        confidence = 0.6
    elif max_path_risk == "critical":
        confidence = 0.95

    # Human-readable rationale for the notable factors.
    if modifier > 1.5:
        reasons.append(f"Force/recursive flags (modifier: {modifier:.1f}x)")
    if scope > 1.5:
        reasons.append(f"Wide scope (wildcards/multiple targets, {scope:.1f}x)")
    if recency > 1.2:
        reasons.append(f"Repeated command pattern ({recency:.1f}x escalation)")
    if not reasons:
        reasons.append(f"Command '{cmd_name}' classified as {base_tier.name}")

    return RiskResult(
        tier=base_tier,
        confidence=round(confidence, 2),
        reasons=reasons,
        context_factors={
            "path_risk": max_path_risk,
            "modifier": round(modifier, 2),
            "scope": round(scope, 2),
            "recency": round(recency, 2),
            "paths": paths,
        },
    )