"""Tests for gateway/stream_consumer.py - Stream consumption and backpressure.

Tests message streaming, backpressure handling, and reconnection logic.
"""

import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from types import SimpleNamespace

try:
    from gateway.stream_consumer import (
        StreamConsumer,
        BackpressureStrategy,
        MessageBuffer,
        ReconnectPolicy,
        StreamError,
    )
    HAS_MODULE = True
except ImportError:
    HAS_MODULE = False

# Only the skip is applied module-wide. The asyncio mark is attached to each
# async test class instead, so the synchronous TestReconnectPolicy tests are
# not flagged by pytest-asyncio as "marked asyncio but not an async function".
pytestmark = [
    pytest.mark.skipif(not HAS_MODULE, reason="stream_consumer module not found"),
]


@pytest.mark.asyncio
class TestMessageBuffer:
    """Tests for message buffering."""

    async def test_buffer_basic_operations(self):
        """Should support basic put/get operations."""
        buffer = MessageBuffer(max_size=100)

        await buffer.put("message1")
        await buffer.put("message2")

        assert buffer.size() == 2

        msg1 = await buffer.get()
        msg2 = await buffer.get()

        assert msg1 == "message1"
        assert msg2 == "message2"

    async def test_buffer_respects_max_size(self):
        """Should block put when buffer is full."""
        buffer = MessageBuffer(max_size=2)

        await buffer.put("msg1")
        await buffer.put("msg2")

        # Third put should block; wait_for cancels it after the timeout.
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(buffer.put("msg3"), timeout=0.1)

    async def test_buffer_clear(self):
        """Should clear all messages."""
        buffer = MessageBuffer(max_size=100)

        await buffer.put("msg1")
        await buffer.put("msg2")

        buffer.clear()

        assert buffer.size() == 0

    async def test_buffer_peek(self):
        """Should peek at next message without removing."""
        buffer = MessageBuffer(max_size=100)

        await buffer.put("msg1")

        peeked = buffer.peek()

        assert peeked == "msg1"
        assert buffer.size() == 1  # Not removed


@pytest.mark.asyncio
class TestBackpressureStrategies:
    """Tests for backpressure handling strategies."""

    async def test_drop_oldest_strategy(self):
        """Should drop oldest messages when buffer full."""
        strategy = BackpressureStrategy.DROP_OLDEST
        buffer = MessageBuffer(max_size=3, backpressure_strategy=strategy)

        await buffer.put("old1")
        await buffer.put("old2")
        await buffer.put("old3")

        # Add new message - should drop oldest
        await buffer.put_with_backpressure("new")

        assert buffer.size() == 3
        assert "old1" not in list(buffer.items())
        assert "new" in list(buffer.items())

    async def test_drop_newest_strategy(self):
        """Should drop newest messages when buffer full."""
        strategy = BackpressureStrategy.DROP_NEWEST
        buffer = MessageBuffer(max_size=3, backpressure_strategy=strategy)

        await buffer.put("msg1")
        await buffer.put("msg2")
        await buffer.put("msg3")

        # Try to add new message - should be dropped
        result = await buffer.put_with_backpressure("new")

        assert buffer.size() == 3
        assert "new" not in list(buffer.items())
        assert result is False  # Indicate message was dropped

    async def test_block_strategy(self):
        """Should block producer when buffer full."""
        strategy = BackpressureStrategy.BLOCK
        buffer = MessageBuffer(max_size=2, backpressure_strategy=strategy)

        await buffer.put("msg1")
        await buffer.put("msg2")

        # Start put in background
        put_task = asyncio.create_task(buffer.put_with_backpressure("msg3"))

        # Should be blocked
        await asyncio.sleep(0.05)
        assert not put_task.done()

        # Remove item - should unblock
        await buffer.get()
        await asyncio.wait_for(put_task, timeout=0.1)

        assert buffer.size() == 2


@pytest.mark.asyncio
class TestStreamConsumer:
    """Tests for stream consumer functionality."""

    async def test_consumer_start_stop(self):
        """Should start and stop cleanly."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com/stream",
            message_handler=AsyncMock(),
        )

        with patch.object(consumer, '_connect', new_callable=AsyncMock):
            await consumer.start()
            assert consumer.is_running

            await consumer.stop()
            assert not consumer.is_running

    async def test_message_handler_invocation(self):
        """Should invoke message handler for each message."""
        handler = AsyncMock()
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
        )

        test_message = {"id": "1", "content": "test"}
        await consumer._process_message(test_message)

        handler.assert_called_once_with(test_message)

    async def test_message_batching(self):
        """Should batch messages when batch_size configured."""
        handler = AsyncMock()
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
            batch_size=3,
            batch_timeout=1.0,
        )

        # Add messages
        await consumer._buffer.put({"id": "1"})
        await consumer._buffer.put({"id": "2"})

        # Should not invoke handler yet
        handler.assert_not_called()

        # Add the third message to fill the batch, then flush it explicitly.
        await consumer._buffer.put({"id": "3"})
        await consumer._flush_batch()

        handler.assert_called_once()
        assert len(handler.call_args[0][0]) == 3

    async def test_error_handling(self):
        """Should handle handler errors gracefully."""
        handler = AsyncMock(side_effect=Exception("Handler error"))
        error_callback = AsyncMock()

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
            error_handler=error_callback,
        )

        # is_running only becomes True after start() (see test_consumer_start_stop),
        # so start the consumer - with the network connect stubbed out - before
        # checking that a handler error does not stop it.
        with patch.object(consumer, '_connect', new_callable=AsyncMock):
            await consumer.start()
            try:
                await consumer._process_message({"id": "1"})

                error_callback.assert_called_once()
                assert consumer.is_running  # Should continue running
            finally:
                await consumer.stop()


class TestReconnectPolicy:
    """Tests for reconnection logic (synchronous - no asyncio mark)."""

    def test_exponential_backoff(self):
        """Should use exponential backoff for retries."""
        policy = ReconnectPolicy(
            max_retries=5,
            base_delay=1.0,
            max_delay=30.0,
            exponential_base=2.0,
        )

        delays = [policy.get_delay(attempt) for attempt in range(5)]

        assert delays[0] == 1.0
        assert delays[1] == 2.0
        assert delays[2] == 4.0
        assert delays[3] == 8.0
        assert delays[4] == 16.0  # Still below max_delay (30.0), so not capped

    def test_max_delay_cap(self):
        """Should cap delay at max_delay."""
        policy = ReconnectPolicy(
            max_retries=10,
            base_delay=1.0,
            max_delay=5.0,
            exponential_base=2.0,
        )

        delay = policy.get_delay(attempt=10)
        assert delay <= 5.0

    def test_jitter_addition(self):
        """Should add jitter to prevent thundering herd."""
        policy = ReconnectPolicy(
            max_retries=5,
            base_delay=1.0,
            jitter=True,
            jitter_range=(0.0, 0.5),
        )

        delays = [policy.get_delay(0) for _ in range(10)]

        # Jitter should yield more than one distinct value; with 10 random
        # draws, all of them colliding is vanishingly unlikely.
        assert len(set(delays)) > 1

        # All should be within expected range
        assert all(1.0 <= d <= 1.5 for d in delays)

    def test_retry_exhaustion(self):
        """Should indicate when retries exhausted."""
        policy = ReconnectPolicy(max_retries=3)

        assert policy.should_retry(0) is True
        assert policy.should_retry(1) is True
        assert policy.should_retry(2) is True
        assert policy.should_retry(3) is False
        assert policy.should_retry(4) is False


@pytest.mark.asyncio
class TestStreamConsumerReconnect:
    """Tests for consumer reconnection behavior."""

    async def test_reconnect_on_connection_error(self):
        """Should reconnect on connection error."""
        attempts = 0

        def flaky_connect(*args, **kwargs):
            # First attempt fails, every later attempt succeeds. A fixed
            # two-item side_effect list would raise StopAsyncIteration if the
            # consumer happened to call _connect more than twice.
            nonlocal attempts
            attempts += 1
            if attempts == 1:
                raise Exception("Connection failed")
            return MagicMock()

        connect_mock = AsyncMock(side_effect=flaky_connect)

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
            reconnect_policy=ReconnectPolicy(max_retries=3, base_delay=0.1),
        )

        with patch.object(consumer, '_connect', connect_mock):
            await consumer.start()
            try:
                # Simulate connection error
                await consumer._handle_connection_error()

                # Should have attempted reconnect
                assert connect_mock.call_count >= 2
            finally:
                # Stop while still patched so no background work can reach the
                # real _connect or outlive this test.
                await consumer.stop()

    async def test_message_ordering_after_reconnect(self):
        """Should maintain message ordering after reconnect."""
        received_messages = []

        async def handler(msg):
            received_messages.append(msg["seq"])

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
        )

        # Simulate messages arriving during reconnection
        await consumer._buffer.put({"seq": 1})
        await consumer._buffer.put({"seq": 2})

        # Process all
        while consumer._buffer.size() > 0:
            await consumer._process_one()

        assert received_messages == [1, 2]

    async def test_graceful_shutdown_during_reconnect(self):
        """Should shutdown gracefully even during reconnection."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
            reconnect_policy=ReconnectPolicy(max_retries=100, base_delay=1.0),
        )

        # Start reconnect loop
        reconnect_task = asyncio.create_task(consumer._reconnect_loop())
        await asyncio.sleep(0.05)

        try:
            # Stop should cancel reconnect
            await consumer.stop()

            # A cancelled/finished task still needs an event-loop turn to
            # become done(). The wait is well under base_delay (1.0s), so a loop
            # that only notices the stop after its backoff sleep still fails.
            # asyncio.wait never raises the task's exception/cancellation.
            await asyncio.wait({reconnect_task}, timeout=0.5)

            assert reconnect_task.done()
            assert not consumer.is_running
        finally:
            # Never leak the reconnect task into later tests, even on failure.
            if not reconnect_task.done():
                reconnect_task.cancel()
                await asyncio.wait({reconnect_task}, timeout=1.0)


@pytest.mark.asyncio
class TestStreamConsumerMetrics:
    """Tests for consumer metrics and observability."""

    async def test_message_count_tracking(self):
        """Should track message counts."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
        )

        await consumer._process_message({"id": "1"})
        await consumer._process_message({"id": "2"})
        await consumer._process_message({"id": "3"})

        assert consumer.metrics.messages_received == 3

    async def test_error_count_tracking(self):
        """Should track error counts."""
        handler = AsyncMock(side_effect=Exception("Error"))
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
        )

        await consumer._process_message({"id": "1"})
        await consumer._process_message({"id": "2"})

        assert consumer.metrics.errors == 2

    async def test_latency_tracking(self):
        """Should track processing latency."""

        async def slow_handler(msg):
            await asyncio.sleep(0.05)

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=slow_handler,
        )

        await consumer._process_message({"id": "1"})

        # NOTE(review): on platforms with coarse timer resolution the sleep can
        # finish slightly early; relax the bound if this proves flaky.
        assert consumer.metrics.avg_latency_ms >= 50