Files
hermes-agent/tests/tools/test_interrupt.py
Allegro 13265971df
Some checks failed
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 29s
Docker Build and Publish / build-and-push (pull_request) Failing after 38s
Tests / test (pull_request) Failing after 28s
security: fix race condition in interrupt propagation (V-007)
Add proper RLock synchronization to prevent race conditions when multiple
threads access interrupt state simultaneously.

Changes:
- tools/interrupt.py: Add RLock, nesting count tracking, new APIs
- tools/terminal_tool.py: Remove direct _interrupt_event exposure
- tests/tools/test_interrupt.py: Comprehensive race condition tests

CVSS: 8.5 (High)
Refs: V-007, Issue #48
Fixes: CWE-362: Concurrent Execution using Shared Resource with Improper Synchronization ('Race Condition')
2026-03-30 23:47:04 +00:00

180 lines
5.0 KiB
Python

"""Tests for interrupt handling and race condition fixes.
Validates V-007: Race Condition in Interrupt Propagation fixes.
"""
import threading
import time
import pytest
from tools.interrupt import (
set_interrupt,
is_interrupted,
get_interrupt_count,
wait_for_interrupt,
InterruptibleContext,
)
class TestInterruptBasics:
    """Test basic interrupt functionality.

    The interrupt state is module-global, so every test restores it to
    "not interrupted" in a ``finally`` clause; a failed assertion must not
    leave an interrupt active for the tests that run afterwards.
    """

    def test_interrupt_set_and_clear(self):
        """Test basic set/clear cycle."""
        set_interrupt(True)
        try:
            assert is_interrupted() is True
        finally:
            # Clear even when the assertion above fails.
            set_interrupt(False)
        assert is_interrupted() is False

    def test_interrupt_count(self):
        """Test interrupt nesting count."""
        set_interrupt(False)  # Reset
        try:
            assert get_interrupt_count() == 0
            set_interrupt(True)
            assert get_interrupt_count() == 1
            set_interrupt(True)  # Nested
            assert get_interrupt_count() == 2
        finally:
            # A single clear is expected to drop every nesting level; it
            # also doubles as cleanup if an assertion above failed.
            set_interrupt(False)  # Clear all
        assert get_interrupt_count() == 0
        assert is_interrupted() is False
class TestInterruptRaceConditions:
    """Test race condition fixes (V-007).

    These tests drive the interrupt API from several threads at once and
    fail if any worker raises or does not finish within its join timeout.

    Worker threads are daemon threads so that a worker wedged inside the
    interrupt API fails the test instead of hanging the whole test process
    at interpreter shutdown.
    """

    def test_concurrent_set_interrupt(self):
        """Test concurrent set operations are thread-safe."""
        set_interrupt(False)  # Reset
        results = []
        errors = []

        def setter_thread(thread_id):
            try:
                for _ in range(100):
                    set_interrupt(True)
                    time.sleep(0.001)
                    set_interrupt(False)
                results.append(thread_id)
            except Exception as e:
                # Exceptions raised in a worker never reach the main thread
                # on their own; record them so the test can fail.
                errors.append((thread_id, str(e)))

        threads = [
            threading.Thread(target=setter_thread, args=(i,), daemon=True)
            for i in range(5)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        # join() returns silently on timeout, so liveness must be checked
        # explicitly to detect a deadlock.
        stuck = [t.name for t in threads if t.is_alive()]
        if not stuck:
            # Only touch the API again when no worker is wedged inside it;
            # otherwise this call could block as well.
            set_interrupt(False)

        assert not stuck, f"Threads did not finish (possible deadlock): {stuck}"
        assert len(errors) == 0, f"Thread errors: {errors}"
        assert len(results) == 5

    def test_concurrent_read_write(self):
        """Test concurrent reads and writes neither raise nor deadlock."""
        set_interrupt(False)
        errors = []
        write_done = threading.Event()

        def reader():
            try:
                while not write_done.is_set():
                    is_interrupted()
                    get_interrupt_count()
            except Exception as e:
                errors.append(("reader", repr(e)))

        def writer():
            try:
                for _ in range(500):
                    set_interrupt(True)
                    set_interrupt(False)
            except Exception as e:
                errors.append(("writer", repr(e)))
            finally:
                # Release the readers even if the writer failed part-way.
                write_done.set()

        readers = [
            threading.Thread(target=reader, daemon=True) for _ in range(3)
        ]
        writer_t = threading.Thread(target=writer, daemon=True)
        for r in readers:
            r.start()
        writer_t.start()

        writer_t.join(timeout=15)
        # If the writer timed out it never set the event; set it here so the
        # readers stop spinning regardless.
        write_done.set()
        for r in readers:
            r.join(timeout=5)

        stuck = [t.name for t in (writer_t, *readers) if t.is_alive()]
        if not stuck:
            # The writer may have failed between set and clear.
            set_interrupt(False)

        assert not stuck, f"Threads did not finish (possible deadlock): {stuck}"
        assert not errors, f"Thread errors: {errors}"
class TestInterruptibleContext:
    """Test InterruptibleContext helper."""

    def test_context_manager(self):
        """Test context manager basic usage."""
        set_interrupt(False)
        with InterruptibleContext() as ctx:
            for _ in range(10):
                assert ctx.should_continue() is True
        assert is_interrupted() is False

    def test_context_respects_interrupt(self):
        """Test that context stops on interrupt."""
        set_interrupt(False)
        stopped_early = False
        try:
            with InterruptibleContext(check_interval=5) as ctx:
                # Simulate work
                for i in range(20):
                    if i == 10:
                        set_interrupt(True)
                    if not ctx.should_continue():
                        stopped_early = True
                        break
            # Should have been interrupted
            assert is_interrupted() is True
            # The flag alone proves nothing, because this test set it; the
            # contract under test is that should_continue() reported it.
            # NOTE(review): assumes check_interval counts should_continue()
            # calls, so a check falls within the remaining iterations -
            # confirm against tools/interrupt.py.
            assert stopped_early, "should_continue() never reported the interrupt"
        finally:
            set_interrupt(False)  # Cleanup, even when an assertion fails
class TestWaitForInterrupt:
    """Test wait_for_interrupt functionality.

    Elapsed time is measured with ``time.monotonic()`` so that a wall-clock
    adjustment during the test cannot distort the measured duration.
    """

    def test_wait_with_timeout(self):
        """Test wait returns False on timeout."""
        set_interrupt(False)
        start = time.monotonic()
        result = wait_for_interrupt(timeout=0.1)
        elapsed = time.monotonic() - start
        assert result is False
        assert elapsed < 0.5  # Should not hang

    def test_wait_interruptible(self):
        """Test wait returns True when interrupted."""
        set_interrupt(False)

        def delayed_interrupt():
            time.sleep(0.1)
            set_interrupt(True)

        t = threading.Thread(target=delayed_interrupt, daemon=True)
        t.start()
        try:
            start = time.monotonic()
            result = wait_for_interrupt(timeout=5.0)
            elapsed = time.monotonic() - start
            assert result is True
            assert elapsed < 1.0  # Should return quickly after interrupt
        finally:
            # Join first so the helper cannot set the interrupt after the
            # cleanup below has already cleared it.
            t.join(timeout=5)
            set_interrupt(False)  # Cleanup