Files
hermes-agent/tests/gateway/test_stream_consumer.py
Allegro 10271c6b44
Some checks failed
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Failing after 25s
Tests / test (pull_request) Failing after 24s
Docker Build and Publish / build-and-push (pull_request) Failing after 35s
security: fix command injection vulnerabilities (CVSS 9.8)
Replace shell=True with list-based subprocess execution to prevent
command injection via malicious user input.

Changes:
- tools/transcription_tools.py: Use shlex.split() + shell=False
- tools/environments/docker.py: List-based commands with container ID validation

Fixes a CVE-level vulnerability in which malicious file paths or container IDs
could be used to inject arbitrary commands.

CVSS: 9.8 (Critical)
Refs: V-001 in SECURITY_AUDIT_REPORT.md
2026-03-30 23:15:11 +00:00

375 lines
12 KiB
Python

"""Tests for gateway/stream_consumer.py - Stream consumption and backpressure.
Tests message streaming, backpressure handling, and reconnection logic.
"""
import pytest
import asyncio
from unittest.mock import patch, MagicMock, AsyncMock
from types import SimpleNamespace
# The module under test is optional: if it cannot be imported, HAS_MODULE is
# False and every test in this file is skipped via ``pytestmark`` below.
try:
    from gateway.stream_consumer import (
        BackpressureStrategy,
        MessageBuffer,
        ReconnectPolicy,
        StreamConsumer,
        StreamError,
    )
except ImportError:
    HAS_MODULE = False
else:
    HAS_MODULE = True

# Marks applied to every test in this module: skip when the import above
# failed, and run tests under the pytest-asyncio ``asyncio`` marker.
# NOTE(review): this also marks the synchronous TestReconnectPolicy tests;
# recent pytest-asyncio versions warn about that — consider marking only the
# async test classes instead.
pytestmark = [
    pytest.mark.skipif(not HAS_MODULE, reason="stream_consumer module not found"),
    pytest.mark.asyncio,
]
class TestMessageBuffer:
    """Tests for message buffering."""

    async def test_buffer_basic_operations(self):
        """Should support basic put/get operations in FIFO order."""
        buffer = MessageBuffer(max_size=100)
        await buffer.put("message1")
        await buffer.put("message2")
        assert buffer.size() == 2

        msg1 = await buffer.get()
        msg2 = await buffer.get()
        assert msg1 == "message1"
        assert msg2 == "message2"
        # Both queued messages have been consumed.
        assert buffer.size() == 0

    async def test_buffer_respects_max_size(self):
        """Should block put when buffer is full."""
        buffer = MessageBuffer(max_size=2)
        await buffer.put("msg1")
        await buffer.put("msg2")

        # Third put should block until wait_for's timeout cancels it.
        with pytest.raises(asyncio.TimeoutError):
            await asyncio.wait_for(buffer.put("msg3"), timeout=0.1)
        # The cancelled put must not have pushed the buffer past max_size.
        assert buffer.size() == 2

    async def test_buffer_clear(self):
        """Should clear all messages."""
        buffer = MessageBuffer(max_size=100)
        await buffer.put("msg1")
        await buffer.put("msg2")

        buffer.clear()
        assert buffer.size() == 0

    async def test_buffer_peek(self):
        """Should peek at next message without removing."""
        buffer = MessageBuffer(max_size=100)
        await buffer.put("msg1")

        peeked = buffer.peek()
        assert peeked == "msg1"
        assert buffer.size() == 1  # Not removed
class TestBackpressureStrategies:
    """Tests for backpressure handling strategies."""

    async def test_drop_oldest_strategy(self):
        """Should drop oldest messages when buffer full."""
        strategy = BackpressureStrategy.DROP_OLDEST
        buffer = MessageBuffer(max_size=3, backpressure_strategy=strategy)
        await buffer.put("old1")
        await buffer.put("old2")
        await buffer.put("old3")

        # Add new message - should evict the oldest one to make room.
        await buffer.put_with_backpressure("new")

        items = list(buffer.items())
        assert buffer.size() == 3
        assert "old1" not in items
        assert "new" in items

    async def test_drop_newest_strategy(self):
        """Should drop newest messages when buffer full."""
        strategy = BackpressureStrategy.DROP_NEWEST
        buffer = MessageBuffer(max_size=3, backpressure_strategy=strategy)
        await buffer.put("msg1")
        await buffer.put("msg2")
        await buffer.put("msg3")

        # Try to add new message - should be dropped.
        result = await buffer.put_with_backpressure("new")

        assert buffer.size() == 3
        assert "new" not in list(buffer.items())
        assert result is False  # Indicates the message was dropped

    async def test_block_strategy(self):
        """Should block producer when buffer full."""
        strategy = BackpressureStrategy.BLOCK
        buffer = MessageBuffer(max_size=2, backpressure_strategy=strategy)
        await buffer.put("msg1")
        await buffer.put("msg2")

        # Start the put in the background so the test can observe it blocking.
        put_task = asyncio.create_task(buffer.put_with_backpressure("msg3"))
        try:
            # Let the task run; it should park on the full buffer.
            await asyncio.sleep(0.05)
            assert not put_task.done()

            # Removing an item frees a slot, which should unblock the put.
            await buffer.get()
            await asyncio.wait_for(put_task, timeout=0.1)
            assert buffer.size() == 2
        finally:
            # Don't leave a pending producer task behind if an assertion failed.
            if not put_task.done():
                put_task.cancel()
                await asyncio.gather(put_task, return_exceptions=True)
class TestStreamConsumer:
    """Tests for stream consumer functionality."""

    async def test_consumer_start_stop(self):
        """Should start and stop cleanly."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com/stream",
            message_handler=AsyncMock(),
        )
        # _connect is stubbed so no real connection is attempted.
        with patch.object(consumer, "_connect", new_callable=AsyncMock):
            await consumer.start()
            try:
                assert consumer.is_running
            finally:
                # Always stop, even if the assertion above failed.
                await consumer.stop()
            assert not consumer.is_running

    async def test_message_handler_invocation(self):
        """Should invoke message handler for each message."""
        handler = AsyncMock()
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
        )

        test_message = {"id": "1", "content": "test"}
        await consumer._process_message(test_message)

        handler.assert_called_once_with(test_message)

    async def test_message_batching(self):
        """Should batch messages when batch_size configured."""
        handler = AsyncMock()
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
            batch_size=3,
            batch_timeout=1.0,
        )

        # Queue fewer messages than batch_size.
        await consumer._buffer.put({"id": "1"})
        await consumer._buffer.put({"id": "2"})
        # A partial batch must not reach the handler.
        handler.assert_not_called()

        # Complete the batch, then flush it explicitly.
        await consumer._buffer.put({"id": "3"})
        await consumer._flush_batch()

        handler.assert_called_once()
        # The whole batch is passed as the handler's first positional argument.
        assert len(handler.call_args[0][0]) == 3

    async def test_error_handling(self):
        """Should report handler errors to error_handler and keep running."""
        handler = AsyncMock(side_effect=Exception("Handler error"))
        error_callback = AsyncMock()
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
            error_handler=error_callback,
        )
        # "Continues running" only means something for a started consumer, so
        # start it (connection stubbed out) before processing the bad message.
        with patch.object(consumer, "_connect", new_callable=AsyncMock):
            await consumer.start()
            try:
                # The handler's exception must not propagate out of processing.
                await consumer._process_message({"id": "1"})

                error_callback.assert_called_once()
                assert consumer.is_running  # Should continue running
            finally:
                await consumer.stop()
class TestReconnectPolicy:
    """Tests for reconnection logic."""

    def test_exponential_backoff(self):
        """Should use exponential backoff for retries."""
        # Jitter is disabled explicitly so the expected delays are exact,
        # whatever the constructor's default is.
        policy = ReconnectPolicy(
            max_retries=5,
            base_delay=1.0,
            max_delay=30.0,
            exponential_base=2.0,
            jitter=False,
        )

        delays = [policy.get_delay(attempt) for attempt in range(5)]

        # base_delay * exponential_base ** attempt. Every value is still below
        # max_delay (30.0), so none of them is capped yet.
        assert delays == [1.0, 2.0, 4.0, 8.0, 16.0]

    def test_max_delay_cap(self):
        """Should cap delay at max_delay."""
        policy = ReconnectPolicy(
            max_retries=10,
            base_delay=1.0,
            max_delay=5.0,
            exponential_base=2.0,
            jitter=False,
        )

        # Uncapped, attempt 10 would be 1.0 * 2.0 ** 10 = 1024.0, so the result
        # must be exactly the cap (a plain `<=` would also accept e.g. 0).
        delay = policy.get_delay(attempt=10)
        assert delay == 5.0

    def test_jitter_addition(self):
        """Should add jitter to prevent thundering herd."""
        policy = ReconnectPolicy(
            max_retries=5,
            base_delay=1.0,
            jitter=True,
            jitter_range=(0.0, 0.5),
        )

        delays = [policy.get_delay(0) for _ in range(10)]

        # All delays should not be identical (with high probability).
        assert len(set(delays)) > 1
        # All should lie within base_delay + jitter_range.
        assert all(1.0 <= d <= 1.5 for d in delays)

    def test_retry_exhaustion(self):
        """Should indicate when retries exhausted."""
        policy = ReconnectPolicy(max_retries=3)

        assert policy.should_retry(0) is True
        assert policy.should_retry(1) is True
        assert policy.should_retry(2) is True
        assert policy.should_retry(3) is False
        assert policy.should_retry(4) is False
class TestStreamConsumerReconnect:
    """Tests for consumer reconnection behavior."""

    async def test_reconnect_on_connection_error(self):
        """Should reconnect on connection error."""
        connect_mock = AsyncMock(side_effect=[
            Exception("Connection failed"),
            MagicMock(),  # Success on second try
        ])
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
            reconnect_policy=ReconnectPolicy(max_retries=3, base_delay=0.1),
        )

        with patch.object(consumer, "_connect", connect_mock):
            try:
                await consumer.start()
                # Simulate connection error
                await consumer._handle_connection_error()
                # Should have attempted reconnect
                assert connect_mock.call_count >= 2
            finally:
                # Don't leave the consumer running after the test.
                await consumer.stop()

    async def test_message_ordering_after_reconnect(self):
        """Should maintain message ordering after reconnect."""
        received_messages = []

        async def handler(msg):
            received_messages.append(msg["seq"])

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
        )

        # Simulate messages arriving during reconnection
        await consumer._buffer.put({"seq": 1})
        await consumer._buffer.put({"seq": 2})

        # Process everything, bounded by the number of queued messages so a
        # _process_one that fails to consume can't hang the test forever.
        for _ in range(consumer._buffer.size()):
            if consumer._buffer.size() == 0:
                break
            await consumer._process_one()

        assert consumer._buffer.size() == 0
        assert received_messages == [1, 2]

    async def test_graceful_shutdown_during_reconnect(self):
        """Should shutdown gracefully even during reconnection."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
            reconnect_policy=ReconnectPolicy(max_retries=100, base_delay=1.0),
        )

        # Every connect attempt fails without touching the network, which
        # keeps the reconnect loop busy retrying.
        failing_connect = AsyncMock(side_effect=ConnectionError("unreachable"))
        with patch.object(consumer, "_connect", failing_connect):
            # Start reconnect loop
            reconnect_task = asyncio.create_task(consumer._reconnect_loop())
            try:
                await asyncio.sleep(0.05)

                # Stop should cancel reconnect
                await consumer.stop()
                # Allow the cancellation to propagate. The grace period is well
                # under base_delay, so the loop must not just sleep out its backoff.
                await asyncio.wait({reconnect_task}, timeout=0.5)

                assert reconnect_task.done()
                assert not consumer.is_running
            finally:
                # Don't leak a still-retrying task if an assertion failed.
                if not reconnect_task.done():
                    reconnect_task.cancel()
                    await asyncio.gather(reconnect_task, return_exceptions=True)
class TestStreamConsumerMetrics:
    """Tests for consumer metrics and observability."""

    async def test_message_count_tracking(self):
        """Should track message counts."""
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=AsyncMock(),
        )

        await consumer._process_message({"id": "1"})
        await consumer._process_message({"id": "2"})
        await consumer._process_message({"id": "3"})

        assert consumer.metrics.messages_received == 3

    async def test_error_count_tracking(self):
        """Should track error counts."""
        handler = AsyncMock(side_effect=Exception("Error"))
        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=handler,
        )

        await consumer._process_message({"id": "1"})
        await consumer._process_message({"id": "2"})

        assert consumer.metrics.errors == 2

    async def test_latency_tracking(self):
        """Should track processing latency."""
        handler_sleep_s = 0.05

        async def slow_handler(msg):
            await asyncio.sleep(handler_sleep_s)

        consumer = StreamConsumer(
            endpoint="ws://test.example.com",
            message_handler=slow_handler,
        )

        await consumer._process_message({"id": "1"})

        # The asyncio event loop may fire a timer up to one clock-resolution
        # tick early, so the measured latency can fall slightly below the
        # nominal sleep. Allow some slack to keep the test from flaking.
        tolerance_ms = 20
        assert consumer.metrics.avg_latency_ms >= handler_sleep_s * 1000 - tolerance_ms