2026-02-23 02:11:33 -08:00
|
|
|
"""Shared interrupt signaling for all tools.
|
|
|
|
|
|
|
|
|
|
Provides a global threading.Event that any tool can check to determine
|
|
|
|
|
if the user has requested an interrupt. The agent's interrupt() method
|
|
|
|
|
sets this event, and tools poll it during long-running operations.
|
|
|
|
|
|
2026-03-30 23:47:04 +00:00
|
|
|
SECURITY FIX (V-007): Added proper locking to prevent race conditions
|
|
|
|
|
in interrupt propagation. Uses RLock for thread-safe nested access.
|
|
|
|
|
|
2026-02-23 02:11:33 -08:00
|
|
|
Usage in tools:
|
|
|
|
|
from tools.interrupt import is_interrupted
|
|
|
|
|
if is_interrupted():
|
|
|
|
|
return {"output": "[interrupted]", "returncode": 130}
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import threading
from typing import Optional
|
|
|
|
|
|
2026-03-30 23:47:04 +00:00
|
|
|
# --- Shared interrupt state -------------------------------------------------

# The flag every tool polls. threading.Event methods are thread-safe on their own.
_interrupt_event = threading.Event()

# Guards updates that touch both the Event and the counter, so they change
# together. Re-entrant: a thread already holding it may acquire it again.
_interrupt_lock = threading.RLock()

# Number of set_interrupt(True) calls since the flag was last cleared
# (set_interrupt(False) resets it to 0).
_interrupt_count: int = 0
|
2026-02-23 02:11:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_interrupt(active: bool) -> None:
    """Raise or lower the shared interrupt flag.

    Intended to be called by the agent. A truthy ``active`` increments the
    interrupt counter and sets the flag. A falsy ``active`` resets the
    counter to zero and clears the flag.

    SECURITY FIX: the counter and the flag are both updated while holding
    the module RLock, so concurrent set/clear calls from different threads
    cannot interleave their updates.
    """
    global _interrupt_count

    with _interrupt_lock:
        if not active:
            _interrupt_count = 0
            _interrupt_event.clear()
            return
        _interrupt_count += 1
        _interrupt_event.set()
|
2026-02-23 02:11:33 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_interrupted() -> bool:
    """Report whether an interrupt is currently requested.

    Reads the shared Event directly without taking the module lock, because
    ``Event.is_set`` is itself thread-safe. It may be called from any thread.
    """
    flag_set = _interrupt_event.is_set()
    return flag_set
|
2026-03-30 23:47:04 +00:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_interrupt_count() -> int:
    """Return how many times set_interrupt(True) ran since the last clear.

    Intended as a debugging aid. The counter is read while holding the
    module lock, so the value is not torn by a concurrent set_interrupt().
    """
    with _interrupt_lock:
        snapshot = _interrupt_count
    return snapshot
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wait_for_interrupt(timeout: Optional[float] = None) -> bool:
    """Block until an interrupt is requested or ``timeout`` expires.

    Args:
        timeout: Maximum time to wait, in seconds. ``None`` (the default)
            waits until the interrupt flag is set, with no time limit.

    Returns:
        True if the interrupt flag is set (either already, or during the
        wait), False if the timeout expired first.
    """
    # threading.Event.wait is thread-safe on its own; no module lock needed.
    return _interrupt_event.wait(timeout)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InterruptibleContext:
    """Context manager for interruptible operations.

    Polls the shared interrupt flag from inside a work loop. To keep tight
    loops cheap, it polls only once every ``check_interval`` calls to
    :meth:`should_continue`. The very first call always polls, so a loop
    entered after an interrupt was already requested does no work.

    Usage:
        with InterruptibleContext() as ctx:
            while ctx.should_continue():
                do_work()
    """

    def __init__(self, check_interval: int = 100):
        """Create a poller.

        Args:
            check_interval: Poll the interrupt flag on the first call to
                should_continue() and then once every ``check_interval``
                calls. Must be >= 1.

        Raises:
            ValueError: If ``check_interval`` is less than 1.
        """
        if check_interval < 1:
            # 0 would otherwise raise ZeroDivisionError from the modulo in
            # should_continue(); reject it up front with a clear message.
            raise ValueError(
                f"check_interval must be >= 1, got {check_interval!r}"
            )
        self.check_interval = check_interval
        self._iteration = 0        # number of should_continue() calls so far
        self._interrupted = False  # result of the most recent poll

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def should_continue(self) -> bool:
        """Return True while the operation should keep going (not interrupted).

        Between polls, the result of the most recent poll is reused.
        """
        # Poll on calls 1, 1 + k, 1 + 2k, ... where k == check_interval.
        # Checking before incrementing makes the first call poll immediately.
        if self._iteration % self.check_interval == 0:
            self._interrupted = is_interrupted()
        self._iteration += 1
        return not self._interrupted
|