diff --git a/tests/tools/test_approval.py b/tests/tools/test_approval.py index a36b2e1f8..0cb4c3571 100644 --- a/tests/tools/test_approval.py +++ b/tests/tools/test_approval.py @@ -2,12 +2,14 @@ from unittest.mock import patch as mock_patch +import tools.approval as approval_module from tools.approval import ( approve_session, clear_session, detect_dangerous_command, has_pending, is_approved, + load_permanent, pop_pending, prompt_dangerous_approval, submit_pending, @@ -368,6 +370,20 @@ class TestPatternKeyUniqueness: ) clear_session(session) + def test_legacy_find_key_still_approves_find_exec(self): + """Old allowlist entry 'find' should keep approving the matching command.""" + _, key_exec, _ = detect_dangerous_command("find . -exec rm {} \\;") + with mock_patch.object(approval_module, "_permanent_approved", set()): + load_permanent({"find"}) + assert is_approved("legacy-find", key_exec) is True + + def test_legacy_find_key_still_approves_find_delete(self): + """Old colliding allowlist entry 'find' should remain backwards compatible.""" + _, key_delete, _ = detect_dangerous_command("find . -name '*.tmp' -delete") + with mock_patch.object(approval_module, "_permanent_approved", set()): + load_permanent({"find"}) + assert is_approved("legacy-find", key_delete) is True + class TestViewFullCommand: """Tests for the 'view full command' option in prompt_dangerous_approval.""" diff --git a/tools/approval.py b/tools/approval.py index 21baedbd0..7c376f0e9 100644 --- a/tools/approval.py +++ b/tools/approval.py @@ -50,6 +50,29 @@ DANGEROUS_PATTERNS = [ ] +def _legacy_pattern_key(pattern: str) -> str: + """Reproduce the old regex-derived approval key for backwards compatibility.""" + return pattern.split(r'\b')[1] if r'\b' in pattern else pattern[:20] + + +_PATTERN_KEY_ALIASES: dict[str, set[str]] = {} +for _pattern, _description in DANGEROUS_PATTERNS: + _legacy_key = _legacy_pattern_key(_pattern) + _canonical_key = _description + _PATTERN_KEY_ALIASES.setdefault(_canonical_key, set()).update({_canonical_key, _legacy_key}) + _PATTERN_KEY_ALIASES.setdefault(_legacy_key, set()).update({_legacy_key, _canonical_key}) + + +def _approval_key_aliases(pattern_key: str) -> set[str]: + """Return all approval keys that should match this pattern. + + New approvals use the human-readable description string, but older + command_allowlist entries and session approvals may still contain the + historical regex-derived key. + """ + return _PATTERN_KEY_ALIASES.get(pattern_key, {pattern_key}) + + # ========================================================================= # Detection # ========================================================================= @@ -103,11 +126,17 @@ def approve_session(session_key: str, pattern_key: str): def is_approved(session_key: str, pattern_key: str) -> bool: - """Check if a pattern is approved (session-scoped or permanent).""" + """Check if a pattern is approved (session-scoped or permanent). + + Accept both the current canonical key and the legacy regex-derived key so + existing command_allowlist entries continue to work after key migrations. + """ + aliases = _approval_key_aliases(pattern_key) with _lock: - if pattern_key in _permanent_approved: + if any(alias in _permanent_approved for alias in aliases): return True - return pattern_key in _session_approved.get(session_key, set()) + session_approvals = _session_approved.get(session_key, set()) + return any(alias in session_approvals for alias in aliases) def approve_permanent(pattern_key: str):