166 lines
5.8 KiB
Python
166 lines
5.8 KiB
Python
|
|
"""Tests for /queue message consumption after normal agent completion.
|
||
|
|
|
||
|
|
Verifies that messages queued via /queue (which store in
adapter._pending_messages WITHOUT triggering an interrupt) are consumed
after the agent finishes its current task — not silently dropped.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from gateway.platforms.base import (
|
||
|
|
BasePlatformAdapter,
|
||
|
|
MessageEvent,
|
||
|
|
MessageType,
|
||
|
|
PlatformConfig,
|
||
|
|
Platform,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Minimal adapter for testing pending message storage
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class _StubAdapter(BasePlatformAdapter):
    """Bare-bones adapter used to exercise pending-message storage.

    Each method defined here either returns a fixed value or delegates to
    the base class, so tests can focus on the state inherited from
    ``BasePlatformAdapter``.
    """

    def __init__(self):
        """Initialise the base adapter with an enabled Telegram test config."""
        stub_config = PlatformConfig(enabled=True, token="test")
        super().__init__(stub_config, Platform.TELEGRAM)

    async def connect(self) -> bool:
        """Report success unconditionally; no connection is attempted."""
        return True

    async def disconnect(self) -> None:
        """Delegate to the base class's ``_mark_disconnected``."""
        self._mark_disconnected()

    async def send(self, chat_id, content, reply_to=None, metadata=None):
        """Return a successful ``SendResult`` with a fixed message id.

        None of the arguments are inspected.
        """
        # Imported at call time rather than at module level.
        from gateway.platforms.base import SendResult

        outcome = SendResult(success=True, message_id="msg-1")
        return outcome

    async def get_chat_info(self, chat_id):
        """Return a minimal chat description echoing ``chat_id`` with type ``"dm"``."""
        return dict(id=chat_id, type="dm")
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
|
||
|
|
|
||
|
|
class TestQueueMessageStorage:
    """Verify /queue stores messages correctly in adapter._pending_messages."""

    @staticmethod
    def _text_event(text, message_id, **source_attrs):
        """Build a TEXT ``MessageEvent`` whose source is a ``MagicMock``.

        ``source_attrs`` are forwarded to ``MagicMock`` so a test can pin
        attributes such as ``chat_id`` or ``platform`` on the fake source.
        The leading underscore keeps pytest from collecting this helper.
        """
        return MessageEvent(
            text=text,
            message_type=MessageType.TEXT,
            source=MagicMock(**source_attrs),
            message_id=message_id,
        )

    def test_queue_stores_message_in_pending(self):
        """An event written under a session key can be read back from the mapping."""
        stub = _StubAdapter()
        key = "telegram:user:123"
        stub._pending_messages[key] = self._text_event(
            "do this next",
            "q1",
            chat_id="123",
            platform=Platform.TELEGRAM,
        )

        assert key in stub._pending_messages
        stored = stub._pending_messages[key]
        assert stored.text == "do this next"

    def test_get_pending_message_consumes_and_clears(self):
        """The first ``get_pending_message`` returns the event; the second returns None."""
        stub = _StubAdapter()
        key = "telegram:user:123"
        stub._pending_messages[key] = self._text_event(
            "queued prompt",
            "q2",
            chat_id="123",
            platform=Platform.TELEGRAM,
        )

        first_read = stub.get_pending_message(key)
        assert first_read is not None
        assert first_read.text == "queued prompt"

        # Should be consumed (cleared)
        second_read = stub.get_pending_message(key)
        assert second_read is None

    def test_queue_does_not_set_interrupt_event(self):
        """The whole point of /queue — no interrupt signal."""
        stub = _StubAdapter()
        key = "telegram:user:123"

        # Simulate an active session (agent running)
        stub._active_sessions[key] = asyncio.Event()

        # Store a queued message (what /queue does)
        queued = self._text_event("queued", "q3")
        stub._pending_messages[key] = queued

        # The interrupt event should NOT be set
        assert not stub._active_sessions[key].is_set()
        assert not stub.has_pending_interrupt(key)

    def test_regular_message_sets_interrupt_event(self):
        """Contrast: regular messages DO trigger interrupt."""
        stub = _StubAdapter()
        key = "telegram:user:123"

        stub._active_sessions[key] = asyncio.Event()

        # Simulate regular message arrival (what handle_message does)
        incoming = self._text_event("new message", "m1")
        stub._pending_messages[key] = incoming
        stub._active_sessions[key].set()  # this is what handle_message does

        assert stub.has_pending_interrupt(key)
|
||
|
|
|
||
|
|
|
||
|
|
class TestQueueConsumptionAfterCompletion:
    """Verify that pending messages are consumed after normal completion."""

    def test_pending_message_available_after_normal_completion(self):
        """After agent finishes without interrupt, pending message should
        still be retrievable from adapter._pending_messages."""
        adapter = _StubAdapter()
        session_key = "telegram:user:123"

        # Simulate: agent starts, /queue stores a message, agent finishes
        adapter._active_sessions[session_key] = asyncio.Event()
        event = MessageEvent(
            text="process this after",
            message_type=MessageType.TEXT,
            source=MagicMock(),
            message_id="q4",
        )
        adapter._pending_messages[session_key] = event

        # Agent finishes (no interrupt): the session entry is removed while
        # the pending message is left untouched.
        del adapter._active_sessions[session_key]

        # The queued message should still be retrievable
        retrieved = adapter.get_pending_message(session_key)
        assert retrieved is not None
        assert retrieved.text == "process this after"

    def test_multiple_queues_last_one_wins(self):
        """If user /queue's multiple times, last message overwrites."""
        adapter = _StubAdapter()
        session_key = "telegram:user:123"

        # Each assignment replaces the previous entry stored under the key.
        for text in ("first", "second", "third"):
            adapter._pending_messages[session_key] = MessageEvent(
                text=text,
                message_type=MessageType.TEXT,
                source=MagicMock(),
                message_id=f"q-{text}",
            )

        retrieved = adapter.get_pending_message(session_key)
        # Guard first so a dropped message fails as a clear assertion rather
        # than an AttributeError on None, matching the sibling tests.
        assert retrieved is not None
        assert retrieved.text == "third"
|