- Introduced a monitoring mechanism in GatewayRunner to detect incoming messages while an agent is active, allowing for graceful interruption and processing of new messages.
- Enhanced BasePlatformAdapter to manage active sessions and pending messages, ensuring that new messages can interrupt ongoing tasks effectively.
- Improved the handling of pending messages by checking for interrupts and processing them in the correct order, enhancing the user experience during message interactions.
- Updated the cleanup process for active tasks to ensure proper resource management after interruptions.
366 lines
12 KiB
Python
366 lines
12 KiB
Python
"""
|
|
Base platform adapter interface.
|
|
|
|
All platform adapters (Telegram, Discord, WhatsApp) inherit from this
|
|
and implement the required methods.
|
|
"""
|
|
|
|
import asyncio
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime
|
|
from typing import Dict, List, Optional, Any, Callable, Awaitable
|
|
from enum import Enum
|
|
|
|
import sys
|
|
sys.path.insert(0, str(__file__).rsplit("/", 3)[0])
|
|
|
|
from gateway.config import Platform, PlatformConfig
|
|
from gateway.session import SessionSource
|
|
|
|
|
|
class MessageType(Enum):
    """Kind of content carried by an incoming message.

    Member values are the lowercase string tags used to identify each kind.
    """

    # Plain text body.
    TEXT = "text"

    # Non-text content kinds.
    PHOTO = "photo"
    VIDEO = "video"
    AUDIO = "audio"
    VOICE = "voice"
    DOCUMENT = "document"
    STICKER = "sticker"

    # Slash-command text such as "/new" or "/reset".
    COMMAND = "command"
|
|
|
|
|
|
@dataclass
class MessageEvent:
    """
    Incoming message from a platform.

    Normalized representation that all adapters produce.
    """

    # Message content
    text: str
    message_type: MessageType = MessageType.TEXT

    # Source information. The field defaults to None, hence Optional; the base
    # adapter's message processing reads ``source.chat_id``, so adapters
    # should always populate it.
    source: Optional[SessionSource] = None

    # Original platform data
    raw_message: Any = None
    message_id: Optional[str] = None

    # Media attachments
    media_urls: List[str] = field(default_factory=list)
    media_types: List[str] = field(default_factory=list)

    # Reply context: ID of the message this one replies to, if any
    reply_to_message_id: Optional[str] = None

    # Timestamps (naive local time, from datetime.now)
    timestamp: datetime = field(default_factory=datetime.now)

    def is_command(self) -> bool:
        """
        Check if this is a command message (e.g., /new, /reset).

        Returns False when ``text`` is empty or None (e.g. a media message
        without a caption) instead of raising.
        """
        return bool(self.text) and self.text.startswith("/")

    def get_command(self) -> Optional[str]:
        """
        Extract the lowercase command name if this is a command message.

        "/New foo" -> "new". Returns None for non-command messages and for a
        bare "/" (or "/ foo"), which carries no command name.
        """
        if not self.is_command():
            return None
        # is_command() guarantees text starts with "/", so split() yields at
        # least one token; drop the leading "/" from the first one.
        name = self.text.split(maxsplit=1)[0][1:].lower()
        return name or None

    def get_command_args(self) -> str:
        """
        Get the arguments after a command.

        For a non-command message the full text is returned unchanged; for a
        command without arguments an empty string is returned.
        """
        if not self.is_command():
            return self.text
        parts = self.text.split(maxsplit=1)
        return parts[1] if len(parts) > 1 else ""
|
|
|
|
|
|
@dataclass
class SendResult:
    """Outcome of one attempt to deliver a message to a platform."""

    # True when the platform accepted the message.
    success: bool
    # Platform-assigned ID of the delivered message, if the adapter reports one.
    message_id: Optional[str] = None
    # Failure description; the base adapter logs it when a send fails.
    error: Optional[str] = None
    # Untouched platform API response, for adapter-specific inspection.
    raw_response: Any = None
|
|
|
|
|
|
# Contract for message handlers: an async callable that receives the
# normalized MessageEvent and resolves to the reply text. A falsy result
# (None or "") means nothing is sent back to the chat.
MessageHandler = Callable[
    [MessageEvent],
    Awaitable[Optional[str]],
]
|
|
|
|
|
|
class BasePlatformAdapter(ABC):
    """
    Base class for platform adapters.

    Subclasses implement platform-specific logic for:
    - Connecting and authenticating
    - Receiving messages
    - Sending messages/responses
    - Handling media

    Incoming messages are expected to be fed to ``handle_message``. At most
    one message per session (keyed by chat_id) is processed at a time. A
    message that arrives while its session is busy is stored as that
    session's pending message, and the session's interrupt event is set so
    the running handler can detect it through ``has_pending_interrupt`` /
    ``get_pending_message``. Whatever is still pending when the current
    message finishes is processed next.
    """

    def __init__(self, config: PlatformConfig, platform: Platform):
        self.config = config
        self.platform = platform
        self._message_handler: Optional[MessageHandler] = None
        self._running = False

        # Interrupt support, keyed by session key (currently the chat_id).
        # Maps each session that is being processed to an asyncio.Event that
        # is set when a newer message arrives for it.
        self._active_sessions: Dict[str, asyncio.Event] = {}
        # Latest message received while its session was busy. Only one is kept
        # per session: a newer message replaces an older queued one.
        self._pending_messages: Dict[str, MessageEvent] = {}
        # Strong references to fire-and-forget processing tasks. The event loop
        # only keeps weak references to tasks, so an unreferenced task could be
        # garbage-collected before it finishes.
        self._background_tasks: "set[asyncio.Task]" = set()

    @property
    def name(self) -> str:
        """Human-readable name for this adapter."""
        return self.platform.value.title()

    @property
    def is_connected(self) -> bool:
        """Check if adapter is currently connected."""
        return self._running

    def set_message_handler(self, handler: MessageHandler) -> None:
        """
        Set the handler for incoming messages.

        The handler receives a MessageEvent and should return
        an optional response string.
        """
        self._message_handler = handler

    @abstractmethod
    async def connect(self) -> bool:
        """
        Connect to the platform and start receiving messages.

        Returns True if connection was successful.
        """
        pass

    @abstractmethod
    async def disconnect(self) -> None:
        """Disconnect from the platform."""
        pass

    @abstractmethod
    async def send(
        self,
        chat_id: str,
        content: str,
        reply_to: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None
    ) -> SendResult:
        """
        Send a message to a chat.

        Args:
            chat_id: The chat/channel ID to send to
            content: Message content (may be markdown)
            reply_to: Optional message ID to reply to
            metadata: Additional platform-specific options

        Returns:
            SendResult with success status and message ID
        """
        pass

    async def send_typing(self, chat_id: str) -> None:
        """
        Send a typing indicator.

        Override in subclasses if the platform supports it.
        """
        pass

    async def _keep_typing(self, chat_id: str, interval: float = 2.0) -> None:
        """
        Continuously send typing indicator until cancelled.

        Telegram/Discord typing status expires after ~5 seconds, so we refresh every 2
        to recover quickly after progress messages interrupt it.

        The indicator is best-effort: errors from ``send_typing`` are logged
        (first occurrence only) and the loop keeps going. This prevents a
        failed typing task from raising into the caller's cleanup code.
        """
        reported_failure = False
        try:
            while True:
                try:
                    await self.send_typing(chat_id)
                except Exception as e:
                    if not reported_failure:
                        print(f"[{self.name}] Typing indicator failed: {e}")
                        reported_failure = True
                await asyncio.sleep(interval)
        except asyncio.CancelledError:
            pass  # Normal cancellation when handler completes

    async def _stop_typing(self, typing_task: "asyncio.Task") -> None:
        """Cancel a typing-indicator task and wait until it has finished."""
        typing_task.cancel()
        try:
            await typing_task
        except asyncio.CancelledError:
            pass
        except Exception as e:
            # A failed typing indicator must never abort message cleanup.
            print(f"[{self.name}] Typing indicator task failed: {e}")

    async def handle_message(self, event: MessageEvent) -> None:
        """
        Process an incoming message.

        This method returns quickly by spawning background tasks.
        This allows new messages to be processed even while an agent is running,
        enabling interruption support.

        If the message's session is already being processed, the message is
        stored as the session's pending message (replacing any earlier one)
        and the session's interrupt event is set instead.
        """
        if not self._message_handler:
            return

        session_key = event.source.chat_id

        interrupt_event = self._active_sessions.get(session_key)
        if interrupt_event is not None:
            # Store this as a pending message - it will interrupt the running agent
            print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
            self._pending_messages[session_key] = event
            # Signal the interrupt (the processing task checks this)
            interrupt_event.set()
            return  # Don't process now - will be handled after current task finishes

        # Mark the session active *before* spawning the task: create_task()
        # does not start the coroutine immediately, so a second message that
        # arrived in between would otherwise start a concurrent handler.
        self._active_sessions[session_key] = asyncio.Event()

        task = asyncio.create_task(self._process_message_background(event, session_key))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    async def _process_message_background(self, event: MessageEvent, session_key: str) -> None:
        """
        Process *event*, then any messages queued for the same session.

        Messages are handled strictly one at a time. After each one finishes
        (successfully or not), the pending message stored by
        ``handle_message``, if any, is taken and processed next. A loop is
        used instead of recursion so that a long run of interrupts does not
        grow the call stack.
        """
        # Reuse the interrupt event registered by handle_message so that an
        # interrupt signalled before this task started running is not lost.
        interrupt_event = self._active_sessions.get(session_key)
        if interrupt_event is None:
            interrupt_event = asyncio.Event()
            self._active_sessions[session_key] = interrupt_event

        current = event
        try:
            while True:
                await self._run_handler_and_reply(current)

                # No await between this check and the cleanup below, so a new
                # message cannot slip in unnoticed.
                pending_event = self._pending_messages.pop(session_key, None)
                if pending_event is None:
                    break

                print(f"[{self.name}] 📨 Processing queued message from interrupt")
                # Fresh interrupt event for the next message; the previous one
                # has already fired. The session stays registered throughout,
                # so new arrivals are queued rather than run concurrently.
                interrupt_event = asyncio.Event()
                self._active_sessions[session_key] = interrupt_event
                current = pending_event
        finally:
            # Clean up session tracking, but only if the registration is still ours.
            if self._active_sessions.get(session_key) is interrupt_event:
                del self._active_sessions[session_key]

    async def _run_handler_and_reply(self, event: MessageEvent) -> None:
        """
        Run the message handler for one event and send its response, if any.

        Errors are logged and swallowed so that one failing message does not
        prevent queued messages for the same session from being processed.
        The typing indicator runs for the duration and is always stopped.
        """
        # Start continuous typing indicator (refreshes every 2 seconds)
        typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
        try:
            # Call the handler (this can take a while with tool calls)
            response = await self._message_handler(event)
            if response:
                await self._send_response(event, response)
        except Exception as e:
            print(f"[{self.name}] Error handling message: {e}")
            import traceback
            traceback.print_exc()
        finally:
            await self._stop_typing(typing_task)

    async def _send_response(self, event: MessageEvent, response: str) -> None:
        """Reply to *event* with *response*, retrying once with a plain-text fallback on failure."""
        result = await self.send(
            chat_id=event.source.chat_id,
            content=response,
            reply_to=event.message_id,
        )
        if result.success:
            return

        # Log send failures (don't raise - user already saw tool progress)
        print(f"[{self.name}] Failed to send response: {result.error}")
        # Retry once with a labelled copy truncated to 3500 characters.
        fallback_result = await self.send(
            chat_id=event.source.chat_id,
            content=f"(Response formatting failed, plain text:)\n\n{response[:3500]}",
            reply_to=event.message_id,
        )
        if not fallback_result.success:
            print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")

    def has_pending_interrupt(self, session_key: str) -> bool:
        """Return True if a newer message arrived while the session is being processed."""
        interrupt_event = self._active_sessions.get(session_key)
        return interrupt_event is not None and interrupt_event.is_set()

    def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
        """
        Get and clear any pending message for a session.

        Consuming the message here means the adapter will not process it again
        once the current message finishes.
        """
        return self._pending_messages.pop(session_key, None)

    def build_source(
        self,
        chat_id: str,
        chat_name: Optional[str] = None,
        chat_type: str = "dm",
        user_id: Optional[str] = None,
        user_name: Optional[str] = None,
        thread_id: Optional[str] = None
    ) -> SessionSource:
        """
        Helper to build a SessionSource for this platform.

        IDs are stringified; a missing (None or "") user/thread ID becomes
        None, while a numeric 0 ID is kept as "0".
        """
        return SessionSource(
            platform=self.platform,
            chat_id=str(chat_id),
            chat_name=chat_name,
            chat_type=chat_type,
            user_id=self._optional_str(user_id),
            user_name=user_name,
            thread_id=self._optional_str(thread_id),
        )

    @staticmethod
    def _optional_str(value: Any) -> Optional[str]:
        """Stringify *value*, mapping None and "" to None (unlike a truthiness test, 0 survives)."""
        if value is None or value == "":
            return None
        return str(value)

    @abstractmethod
    async def get_chat_info(self, chat_id: str) -> Dict[str, Any]:
        """
        Get information about a chat/channel.

        Returns dict with at least:
        - name: Chat name
        - type: "dm", "group", "channel"
        """
        pass

    def format_message(self, content: str) -> str:
        """
        Format a message for this platform.

        Override in subclasses to handle platform-specific formatting
        (e.g., Telegram MarkdownV2, Discord markdown).

        Default implementation returns content as-is.
        """
        return content

    def truncate_message(self, content: str, max_length: int = 4096) -> List[str]:
        """
        Split a long message into chunks.

        Splits at the last newline before the limit, else at the last space,
        else hard-splits at ``max_length``. Leading whitespace of each
        following chunk is stripped. A separator at index 0 is ignored so that
        no empty chunk is produced.

        Args:
            content: The full message content
            max_length: Maximum length per chunk (platform-specific)

        Returns:
            List of message chunks

        Raises:
            ValueError: if max_length is less than 1 (no progress possible).
        """
        if max_length < 1:
            raise ValueError(f"max_length must be at least 1, got {max_length}")

        if len(content) <= max_length:
            return [content]

        chunks = []
        while content:
            if len(content) <= max_length:
                chunks.append(content)
                break

            # Try to split at a newline, then at a space, else hard split.
            split_idx = content.rfind("\n", 0, max_length)
            if split_idx <= 0:
                split_idx = content.rfind(" ", 0, max_length)
            if split_idx <= 0:
                split_idx = max_length

            chunks.append(content[:split_idx])
            content = content[split_idx:].lstrip()

        return chunks
|