"""Tests for interrupt handling and race condition fixes.

Validates V-007: Race Condition in Interrupt Propagation fixes.
"""

import threading
import time

import pytest

from tools.interrupt import (
    set_interrupt,
    is_interrupted,
    get_interrupt_count,
    wait_for_interrupt,
    InterruptibleContext,
)


@pytest.fixture(autouse=True)
def _reset_interrupt_state():
    """Start and finish every test with the interrupt state cleared.

    The interrupt state is shared across the whole test process, so without
    this a test that fails part-way through (before its own cleanup line)
    would leak a raised interrupt into every later test.
    """
    set_interrupt(False)
    yield
    set_interrupt(False)


class TestInterruptBasics:
    """Test basic interrupt functionality."""

    def test_interrupt_set_and_clear(self):
        """Test basic set/clear cycle."""
        set_interrupt(True)
        assert is_interrupted() is True
        set_interrupt(False)
        assert is_interrupted() is False

    def test_interrupt_count(self):
        """Test interrupt nesting count."""
        # The autouse fixture has already reset the state.
        assert get_interrupt_count() == 0
        set_interrupt(True)
        assert get_interrupt_count() == 1
        set_interrupt(True)  # Nested
        assert get_interrupt_count() == 2
        set_interrupt(False)  # Clear all
        assert get_interrupt_count() == 0
        assert is_interrupted() is False


class TestInterruptRaceConditions:
    """Test race condition fixes (V-007).

    These tests validate that the RLock properly synchronizes concurrent
    access to the interrupt state.

    Worker exceptions are not re-raised in the test thread, and
    ``Thread.join(timeout=...)`` returns silently when it times out, so each
    test collects thread errors and checks ``is_alive()`` explicitly. Workers
    are daemon threads so a deadlocked one cannot keep the test process alive.
    """

    def test_concurrent_set_interrupt(self):
        """Test concurrent set operations are thread-safe."""
        results = []
        errors = []

        def setter_thread(thread_id):
            try:
                for _ in range(100):
                    set_interrupt(True)
                    time.sleep(0.001)
                    set_interrupt(False)
                results.append(thread_id)
            except Exception as e:
                errors.append((thread_id, str(e)))

        threads = [
            threading.Thread(target=setter_thread, args=(i,), daemon=True)
            for i in range(5)
        ]
        for t in threads:
            t.start()
        for t in threads:
            t.join(timeout=10)

        hung = [t.name for t in threads if t.is_alive()]
        assert not hung, f"Threads did not finish (possible deadlock): {hung}"
        assert len(errors) == 0, f"Thread errors: {errors}"
        assert len(results) == 5
        # Every thread's final call is set_interrupt(False), so whichever
        # call ran last globally was a clear: the state must end cleared.
        assert is_interrupted() is False
        assert get_interrupt_count() == 0

    def test_concurrent_read_write(self):
        """Test concurrent reads and writes neither raise nor deadlock."""
        errors = []
        write_done = threading.Event()

        def reader():
            try:
                while not write_done.is_set():
                    is_interrupted()
                    get_interrupt_count()
            except Exception as e:
                errors.append(("reader", str(e)))

        def writer():
            try:
                for _ in range(500):
                    set_interrupt(True)
                    set_interrupt(False)
            except Exception as e:
                errors.append(("writer", str(e)))
            finally:
                # Release the readers even if a write raised.
                write_done.set()

        readers = [threading.Thread(target=reader, daemon=True) for _ in range(3)]
        writer_t = threading.Thread(target=writer, daemon=True)
        for r in readers:
            r.start()
        writer_t.start()

        writer_t.join(timeout=15)
        write_done.set()  # Release readers even if the writer is stuck.
        for r in readers:
            r.join(timeout=5)

        hung = [t.name for t in (writer_t, *readers) if t.is_alive()]
        assert not hung, f"Threads did not finish (possible deadlock): {hung}"
        assert not errors, f"Thread errors: {errors}"


class TestInterruptibleContext:
    """Test InterruptibleContext helper."""

    def test_context_manager(self):
        """Test context manager basic usage."""
        with InterruptibleContext() as ctx:
            for _ in range(10):
                assert ctx.should_continue() is True
        assert is_interrupted() is False

    def test_context_respects_interrupt(self):
        """Test that the context stops the work loop once interrupted."""
        stopped_at = None
        with InterruptibleContext(check_interval=5) as ctx:
            # Simulate work, raising the interrupt half-way through.
            for i in range(20):
                if i == 10:
                    set_interrupt(True)
                if not ctx.should_continue():
                    stopped_at = i
                    break

        # The loop must actually have been cut short by should_continue(),
        # and not before the interrupt was raised. NOTE(review): assumes
        # check_interval counts should_continue() calls, so the interrupt is
        # noticed within the remaining 10 iterations.
        assert stopped_at is not None, "should_continue() never saw the interrupt"
        assert stopped_at >= 10
        assert is_interrupted() is True


class TestWaitForInterrupt:
    """Test wait_for_interrupt functionality."""

    def test_wait_with_timeout(self):
        """Test wait returns False on timeout."""
        # monotonic() is immune to wall-clock adjustments during the wait.
        start = time.monotonic()
        result = wait_for_interrupt(timeout=0.1)
        elapsed = time.monotonic() - start
        assert result is False
        assert elapsed < 0.5  # Should not hang

    def test_wait_interruptible(self):
        """Test wait returns True when interrupted."""

        def delayed_interrupt():
            time.sleep(0.1)
            set_interrupt(True)

        t = threading.Thread(target=delayed_interrupt, daemon=True)
        t.start()
        start = time.monotonic()
        result = wait_for_interrupt(timeout=5.0)
        elapsed = time.monotonic() - start
        t.join(timeout=5)
        assert result is True
        assert elapsed < 1.0  # Should return quickly after interrupt