"""Shared interrupt signaling for all tools.

Provides a global ``threading.Event`` that any tool can check to determine
whether the user has requested an interrupt. The agent's ``interrupt()``
method sets this event, and tools poll it during long-running operations.

SECURITY FIX (V-007): Added proper locking to prevent race conditions in
interrupt propagation. Uses RLock for thread-safe nested access.

Usage in tools::

    from tools.interrupt import is_interrupted

    if is_interrupted():
        return {"output": "[interrupted]", "returncode": 130}
"""

import threading
from typing import Optional

# Global interrupt state. ``threading.Event`` is itself thread-safe; the lock
# serializes writers so the counter and the event are always updated together.
_interrupt_event = threading.Event()
_interrupt_lock = threading.RLock()
_interrupt_count = 0  # Number of set_interrupt(True) calls since the last clear.


def set_interrupt(active: bool) -> None:
    """Signal (``active=True``) or clear (``active=False``) the interrupt.

    Called by the agent. Signaling increments the interrupt counter and sets
    the shared event; clearing resets the counter to zero and clears the
    event. Both updates happen while holding ``_interrupt_lock``.

    SECURITY FIX: Uses RLock to prevent race conditions when multiple
    threads attempt to set/clear the interrupt simultaneously.
    """
    global _interrupt_count
    with _interrupt_lock:
        if active:
            _interrupt_count += 1
            _interrupt_event.set()
        else:
            _interrupt_count = 0
            _interrupt_event.clear()


def is_interrupted() -> bool:
    """Return True if an interrupt has been requested.

    Safe to call from any thread: it only reads the thread-safe event.
    """
    return _interrupt_event.is_set()


def get_interrupt_count() -> int:
    """Return the number of ``set_interrupt(True)`` calls since the last clear.

    Intended for debugging. A single ``set_interrupt(False)`` resets the
    count to zero, no matter how many signals preceded it.
    """
    with _interrupt_lock:
        return _interrupt_count


def wait_for_interrupt(timeout: Optional[float] = None) -> bool:
    """Block until the interrupt is set or ``timeout`` expires.

    Args:
        timeout: Maximum time to wait in seconds; ``None`` waits indefinitely.

    Returns:
        True if the interrupt was set, False if the timeout expired.
    """
    return _interrupt_event.wait(timeout)


class InterruptibleContext:
    """Context manager for interruptible operations.

    ``should_continue()`` polls the global interrupt flag on its first call
    and then once every ``check_interval`` calls. Tight loops therefore pay
    for the check only periodically, while an interrupt that is already
    pending when the loop starts is noticed immediately.

    Usage::

        with InterruptibleContext() as ctx:
            while ctx.should_continue():
                do_work()

    Args:
        check_interval: Poll the interrupt flag once per this many
            ``should_continue()`` calls. Must be >= 1.

    Raises:
        ValueError: If ``check_interval`` is less than 1.
    """

    def __init__(self, check_interval: int = 100):
        if check_interval < 1:
            # 0 would otherwise surface later as a ZeroDivisionError inside
            # should_continue(); reject non-positive values up front.
            raise ValueError(
                f"check_interval must be >= 1, got {check_interval!r}"
            )
        self.check_interval = check_interval
        self._iteration = 0
        self._interrupted = False

    def __enter__(self) -> "InterruptibleContext":
        # Reset polling state so one instance can be reused across
        # several ``with`` blocks without carrying over a stale result.
        self._iteration = 0
        self._interrupted = False
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        # Returning None means exceptions raised in the block propagate.
        return None

    def should_continue(self) -> bool:
        """Return True if the operation should keep going (not interrupted).

        The flag is polled on calls 1, 1 + check_interval,
        1 + 2 * check_interval, ...; in between, the last polled result
        is returned.
        """
        self._iteration += 1
        # Poll on the very first call so a pre-existing interrupt stops the
        # loop at once, then only every ``check_interval`` calls after that.
        if (self._iteration - 1) % self.check_interval == 0:
            self._interrupted = is_interrupted()
        return not self._interrupted