Add proper RLock synchronization to prevent race conditions when multiple threads access the interrupt state simultaneously.

Changes:
- tools/interrupt.py: add RLock, nesting-count tracking, and new APIs
- tools/terminal_tool.py: remove direct _interrupt_event exposure
- tests/tools/test_interrupt.py: add comprehensive race-condition tests

CVSS: 8.5 (High)
Refs: V-007, Issue #48
Fixes: CWE-362: Concurrent Execution using Shared Resource with Improper Synchronization ('Race Condition')
180 lines
5.0 KiB
Python
180 lines
5.0 KiB
Python
"""Tests for interrupt handling and race condition fixes.
|
|
|
|
Validates the fixes for V-007: Race Condition in Interrupt Propagation.
|
|
"""
|
|
|
|
import threading
|
|
import time
|
|
import pytest
|
|
from tools.interrupt import (
|
|
set_interrupt,
|
|
is_interrupted,
|
|
get_interrupt_count,
|
|
wait_for_interrupt,
|
|
InterruptibleContext,
|
|
)
|
|
|
|
|
|
class TestInterruptBasics:
    """Test basic interrupt functionality."""

    def setup_method(self, method):
        """Clear the shared interrupt state before each test."""
        set_interrupt(False)

    def teardown_method(self, method):
        """Clear the interrupt state after each test, even one that failed.

        set_interrupt/is_interrupted take no handle, so the state is shared by
        every test. Without this teardown, an assertion failure midway through
        a test would leave the interrupt set and cascade into later tests.
        """
        set_interrupt(False)

    def test_interrupt_set_and_clear(self):
        """Test basic set/clear cycle."""
        set_interrupt(True)
        assert is_interrupted() is True

        set_interrupt(False)
        assert is_interrupted() is False

    def test_interrupt_count(self):
        """Test interrupt nesting count: each set increments, a clear resets."""
        assert get_interrupt_count() == 0  # setup_method cleared the state

        set_interrupt(True)
        assert get_interrupt_count() == 1
        assert is_interrupted() is True

        set_interrupt(True)  # Nested
        assert get_interrupt_count() == 2

        set_interrupt(False)  # Clear all
        assert get_interrupt_count() == 0
        assert is_interrupted() is False
|
|
|
|
|
|
class TestInterruptRaceConditions:
    """Test race condition fixes (V-007).

    These tests validate that the RLock properly synchronizes
    concurrent access to the interrupt state.

    An exception raised inside a worker thread never propagates to the test
    thread, and ``Thread.join(timeout=...)`` returns silently while a thread
    is still running. Each test therefore collects worker errors into a list
    and checks ``is_alive()`` explicitly after joining. Otherwise a crash or
    a deadlock would go unnoticed.
    """

    def setup_method(self, method):
        """Clear the shared interrupt state before each test."""
        set_interrupt(False)

    def teardown_method(self, method):
        """Clear the interrupt state after each test, even one that failed."""
        set_interrupt(False)

    def test_concurrent_set_interrupt(self):
        """Test concurrent set operations are thread-safe."""
        results = []
        errors = []

        def setter_thread(thread_id):
            try:
                for _ in range(100):
                    set_interrupt(True)
                    time.sleep(0.001)
                    set_interrupt(False)
                results.append(thread_id)
            except Exception as e:  # reported via the errors list below
                errors.append((thread_id, repr(e)))

        # daemon=True so a deadlocked worker cannot block interpreter exit.
        threads = [
            threading.Thread(target=setter_thread, args=(i,), daemon=True)
            for i in range(5)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish (possible deadlock): {hung}"
        assert not errors, f"Thread errors: {errors}"
        assert sorted(results) == list(range(5))

        # Every thread's last call is set_interrupt(False). Once all of them
        # have finished, the globally last operation was a clear, so the
        # state must be fully reset.
        assert is_interrupted() is False
        assert get_interrupt_count() == 0

    def test_concurrent_read_write(self):
        """Test concurrent reads and writes are consistent."""
        errors = []
        write_done = threading.Event()

        def reader():
            try:
                while not write_done.is_set():
                    flag = is_interrupted()
                    count = get_interrupt_count()
                    # The single writer only alternates one set with one
                    # clear, so the count can only ever be 0 or 1.
                    if not isinstance(flag, bool) or count not in (0, 1):
                        errors.append(("reader", flag, count))
                        return  # one report is enough; avoid flooding
            except Exception as e:
                errors.append(("reader", repr(e)))

        def writer():
            try:
                for _ in range(500):
                    set_interrupt(True)
                    set_interrupt(False)
            except Exception as e:
                errors.append(("writer", repr(e)))
            finally:
                # Release the readers even if the writer failed.
                write_done.set()

        readers = [threading.Thread(target=reader, daemon=True) for _ in range(3)]
        writer_t = threading.Thread(target=writer, daemon=True)

        for r in readers:
            r.start()
        writer_t.start()

        writer_t.join(timeout=15)
        write_done.set()
        for r in readers:
            r.join(timeout=5)

        hung = [t.name for t in (writer_t, *readers) if t.is_alive()]
        assert not hung, f"Threads did not finish (possible deadlock): {hung}"
        assert not errors, f"Thread errors: {errors}"
|
|
|
|
|
|
class TestInterruptibleContext:
    """Test InterruptibleContext helper."""

    def setup_method(self, method):
        """Clear the shared interrupt state before each test."""
        set_interrupt(False)

    def teardown_method(self, method):
        """Clear the interrupt state after each test, even one that failed.

        test_context_respects_interrupt deliberately leaves the interrupt set,
        so this teardown prevents it from leaking into later tests.
        """
        set_interrupt(False)

    def test_context_manager(self):
        """Test context manager basic usage."""
        with InterruptibleContext() as ctx:
            for _ in range(10):
                assert ctx.should_continue() is True

        assert is_interrupted() is False

    def test_context_respects_interrupt(self):
        """Test that context stops on interrupt."""
        stopped_at = None

        with InterruptibleContext(check_interval=5) as ctx:
            # Simulate work, raising the interrupt partway through.
            for i in range(20):
                if i == 10:
                    set_interrupt(True)
                if not ctx.should_continue():
                    stopped_at = i
                    break

        # is_interrupted() alone is True simply because this test set it, so
        # also verify that the loop really stopped. Ten should_continue()
        # calls follow the interrupt, which covers at least one full
        # check_interval=5 window -- assuming check_interval counts calls
        # (TODO confirm against tools.interrupt).
        assert stopped_at is not None, "InterruptibleContext ignored the interrupt"
        assert stopped_at >= 10, "should_continue() returned False before any interrupt"
        assert is_interrupted() is True
|
|
|
|
|
|
class TestWaitForInterrupt:
    """Test wait_for_interrupt functionality.

    Elapsed time is measured with time.monotonic(), which, unlike time.time(),
    cannot jump when the system clock is adjusted mid-test.
    """

    def setup_method(self, method):
        """Clear the shared interrupt state before each test."""
        set_interrupt(False)

    def teardown_method(self, method):
        """Clear the interrupt state after each test, even one that failed."""
        set_interrupt(False)

    def test_wait_with_timeout(self):
        """Test wait returns False on timeout."""
        start = time.monotonic()
        result = wait_for_interrupt(timeout=0.1)
        elapsed = time.monotonic() - start

        assert result is False
        # It must actually wait (generous slack for coarse OS timers) ...
        assert elapsed >= 0.05
        # ... but not hang.
        assert elapsed < 0.5

    def test_wait_interruptible(self):
        """Test wait returns True when interrupted."""

        def delayed_interrupt():
            time.sleep(0.1)
            set_interrupt(True)

        t = threading.Thread(target=delayed_interrupt, daemon=True)
        t.start()

        start = time.monotonic()
        result = wait_for_interrupt(timeout=5.0)
        elapsed = time.monotonic() - start

        t.join(timeout=5)

        assert not t.is_alive(), "interrupting thread did not finish"
        assert result is True
        assert elapsed < 1.0  # Should return quickly after interrupt
|