Replace shell=True with list-based subprocess execution to prevent command injection via malicious user input. Changes: - tools/transcription_tools.py: Use shlex.split() + shell=False - tools/environments/docker.py: List-based commands with container ID validation Fixes CVE-level vulnerability where malicious file paths or container IDs could inject arbitrary commands. CVSS: 9.8 (Critical) Refs: V-001 in SECURITY_AUDIT_REPORT.md
375 lines
12 KiB
Python
375 lines
12 KiB
Python
"""Tests for gateway/stream_consumer.py - Stream consumption and backpressure.
|
|
|
|
Tests message streaming, backpressure handling, and reconnection logic.
|
|
"""
|
|
|
|
import pytest
|
|
import asyncio
|
|
from unittest.mock import patch, MagicMock, AsyncMock
|
|
from types import SimpleNamespace
|
|
|
|
try:
|
|
from gateway.stream_consumer import (
|
|
StreamConsumer,
|
|
BackpressureStrategy,
|
|
MessageBuffer,
|
|
ReconnectPolicy,
|
|
StreamError,
|
|
)
|
|
HAS_MODULE = True
|
|
except ImportError:
|
|
HAS_MODULE = False
|
|
|
|
|
|
pytestmark = [
|
|
pytest.mark.skipif(not HAS_MODULE, reason="stream_consumer module not found"),
|
|
pytest.mark.asyncio,
|
|
]
|
|
|
|
|
|
class TestMessageBuffer:
    """Tests for message buffering."""

    async def test_buffer_basic_operations(self):
        """Should support basic put/get operations."""
        buffer = MessageBuffer(max_size=100)

        await buffer.put("message1")
        await buffer.put("message2")

        assert buffer.size() == 2

        msg1 = await buffer.get()
        msg2 = await buffer.get()

        # FIFO ordering: messages come out in insertion order.
        assert msg1 == "message1"
        assert msg2 == "message2"

    async def test_buffer_respects_max_size(self):
        """Should block put when buffer is full."""
        buffer = MessageBuffer(max_size=2)

        await buffer.put("msg1")
        await buffer.put("msg2")

        # Third put should block; the timeout proves it never completed.
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(buffer.put("msg3"), timeout=0.1)

    async def test_buffer_clear(self):
        """Should clear all messages."""
        buffer = MessageBuffer(max_size=100)

        await buffer.put("msg1")
        await buffer.put("msg2")
        buffer.clear()

        assert buffer.size() == 0

    async def test_buffer_peek(self):
        """Should peek at next message without removing."""
        buffer = MessageBuffer(max_size=100)

        await buffer.put("msg1")

        peeked = buffer.peek()
        assert peeked == "msg1"
        assert buffer.size() == 1  # Not removed
|
class TestBackpressureStrategies:
    """Tests for backpressure handling strategies."""

    async def test_drop_oldest_strategy(self):
        """Should drop oldest messages when buffer full."""
        strategy = BackpressureStrategy.DROP_OLDEST
        buffer = MessageBuffer(max_size=3, backpressure_strategy=strategy)

        await buffer.put("old1")
        await buffer.put("old2")
        await buffer.put("old3")

        # Add new message - should drop oldest
        await buffer.put_with_backpressure("new")

        assert buffer.size() == 3
        assert "old1" not in list(buffer.items())
        assert "new" in list(buffer.items())

    async def test_drop_newest_strategy(self):
        """Should drop newest messages when buffer full."""
        strategy = BackpressureStrategy.DROP_NEWEST
        buffer = MessageBuffer(max_size=3, backpressure_strategy=strategy)

        await buffer.put("msg1")
        await buffer.put("msg2")
        await buffer.put("msg3")

        # Try to add new message - should be dropped
        result = await buffer.put_with_backpressure("new")

        assert buffer.size() == 3
        assert "new" not in list(buffer.items())
        assert result is False  # Indicate message was dropped

    async def test_block_strategy(self):
        """Should block producer when buffer full."""
        strategy = BackpressureStrategy.BLOCK
        buffer = MessageBuffer(max_size=2, backpressure_strategy=strategy)

        await buffer.put("msg1")
        await buffer.put("msg2")

        # Start put in background
        put_task = asyncio.create_task(buffer.put_with_backpressure("msg3"))

        # Should still be blocked after yielding to the event loop
        await asyncio.sleep(0.05)
        assert not put_task.done()

        # Remove item - should unblock the pending put
        await buffer.get()
        await asyncio.wait_for(put_task, timeout=0.1)

        assert buffer.size() == 2
|
class TestStreamConsumer:
    """Tests for stream consumer functionality."""

    async def test_consumer_start_stop(self):
        """Should start and stop cleanly."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com/stream",
            message_handler=AsyncMock()
        )

        # Patch out the network connect so start() needs no real endpoint.
        with patch.object(consumer, '_connect', new_callable=AsyncMock):
            await consumer.start()
            assert consumer.is_running

            await consumer.stop()
            assert not consumer.is_running

    async def test_message_handler_invocation(self):
        """Should invoke message handler for each message."""
        handler = AsyncMock()
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler
        )

        test_message = {"id": "1", "content": "test"}
        await consumer._process_message(test_message)

        handler.assert_called_once_with(test_message)

    async def test_message_batching(self):
        """Should batch messages when batch_size configured."""
        handler = AsyncMock()
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
            batch_size=3,
            batch_timeout=1.0
        )

        # Add messages
        await consumer._buffer.put({"id": "1"})
        await consumer._buffer.put({"id": "2"})

        # Should not invoke handler yet
        handler.assert_not_called()

        # Add third message - should trigger batch
        await consumer._buffer.put({"id": "3"})
        await consumer._flush_batch()

        # Handler receives one call with a batch of all three messages.
        handler.assert_called_once()
        assert len(handler.call_args[0][0]) == 3

    async def test_error_handling(self):
        """Should handle handler errors gracefully."""
        handler = AsyncMock(side_effect=Exception("Handler error"))
        error_callback = AsyncMock()

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
            error_handler=error_callback
        )

        await consumer._process_message({"id": "1"})

        error_callback.assert_called_once()
        assert consumer.is_running  # Should continue running
|
class TestReconnectPolicy:
    """Tests for reconnection logic."""

    def test_exponential_backoff(self):
        """Should use exponential backoff for retries."""
        policy = ReconnectPolicy(
            max_retries=5,
            base_delay=1.0,
            max_delay=30.0,
            exponential_base=2.0
        )

        delays = [policy.get_delay(attempt) for attempt in range(5)]

        # delay = base_delay * exponential_base ** attempt
        assert delays[0] == 1.0
        assert delays[1] == 2.0
        assert delays[2] == 4.0
        assert delays[3] == 8.0
        assert delays[4] == 16.0  # 16.0 < max_delay, so no capping yet

    def test_max_delay_cap(self):
        """Should cap delay at max_delay."""
        policy = ReconnectPolicy(
            max_retries=10,
            base_delay=1.0,
            max_delay=5.0,
            exponential_base=2.0
        )

        # 2 ** 10 would be 1024s uncapped; the cap must apply.
        delay = policy.get_delay(attempt=10)
        assert delay <= 5.0

    def test_jitter_addition(self):
        """Should add jitter to prevent thundering herd."""
        policy = ReconnectPolicy(
            max_retries=5,
            base_delay=1.0,
            jitter=True,
            jitter_range=(0.0, 0.5)
        )

        delays = [policy.get_delay(0) for _ in range(10)]

        # All delays should be different (with high probability)
        assert len(set(delays)) > 1
        # All should be within expected range: base_delay + jitter_range
        assert all(1.0 <= d <= 1.5 for d in delays)

    def test_retry_exhaustion(self):
        """Should indicate when retries exhausted."""
        policy = ReconnectPolicy(max_retries=3)

        assert policy.should_retry(0) is True
        assert policy.should_retry(1) is True
        assert policy.should_retry(2) is True
        assert policy.should_retry(3) is False
        assert policy.should_retry(4) is False
|
class TestStreamConsumerReconnect:
    """Tests for consumer reconnection behavior."""

    async def test_reconnect_on_connection_error(self):
        """Should reconnect on connection error."""
        # First connect attempt fails; second succeeds.
        connect_mock = AsyncMock(side_effect=[
            Exception("Connection failed"),
            MagicMock(),  # Success on second try
        ])

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
            reconnect_policy=ReconnectPolicy(max_retries=3, base_delay=0.1)
        )

        with patch.object(consumer, '_connect', connect_mock):
            await consumer.start()

            # Simulate connection error
            await consumer._handle_connection_error()

            # Should have attempted reconnect
            assert connect_mock.call_count >= 2

    async def test_message_ordering_after_reconnect(self):
        """Should maintain message ordering after reconnect."""
        received_messages = []

        async def handler(msg):
            received_messages.append(msg["seq"])

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler
        )

        # Simulate messages arriving during reconnection
        await consumer._buffer.put({"seq": 1})
        await consumer._buffer.put({"seq": 2})

        # Process all
        while consumer._buffer.size() > 0:
            await consumer._process_one()

        assert received_messages == [1, 2]

    async def test_graceful_shutdown_during_reconnect(self):
        """Should shutdown gracefully even during reconnection."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
            reconnect_policy=ReconnectPolicy(max_retries=100, base_delay=1.0)
        )

        # Start reconnect loop
        reconnect_task = asyncio.create_task(consumer._reconnect_loop())
        await asyncio.sleep(0.05)

        # Stop should cancel reconnect
        await consumer.stop()

        assert reconnect_task.done()
        assert not consumer.is_running
|
class TestStreamConsumerMetrics:
    """Tests for consumer metrics and observability."""

    async def test_message_count_tracking(self):
        """Should track message counts."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock()
        )

        await consumer._process_message({"id": "1"})
        await consumer._process_message({"id": "2"})
        await consumer._process_message({"id": "3"})

        assert consumer.metrics.messages_received == 3

    async def test_error_count_tracking(self):
        """Should track error counts."""
        handler = AsyncMock(side_effect=Exception("Error"))
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler
        )

        await consumer._process_message({"id": "1"})
        await consumer._process_message({"id": "2"})

        assert consumer.metrics.errors == 2

    async def test_latency_tracking(self):
        """Should track processing latency."""
        async def slow_handler(msg):
            # Sleep 50ms so the measured latency has a known floor.
            await asyncio.sleep(0.05)

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=slow_handler
        )

        await consumer._process_message({"id": "1"})

        assert consumer.metrics.avg_latency_ms >= 50