Implement interrupt handling for message processing in GatewayRunner and BasePlatformAdapter
- Introduced a monitoring mechanism in GatewayRunner to detect incoming messages while an agent is active, allowing for graceful interruption and processing of new messages. - Enhanced BasePlatformAdapter to manage active sessions and pending messages, ensuring that new messages can interrupt ongoing tasks effectively. - Improved the handling of pending messages by checking for interrupts and processing them in the correct order, enhancing user experience during message interactions. - Updated the cleanup process for active tasks to ensure proper resource management after interruptions.
This commit is contained in:
@@ -108,6 +108,11 @@ class BasePlatformAdapter(ABC):
|
||||
self.platform = platform
|
||||
self._message_handler: Optional[MessageHandler] = None
|
||||
self._running = False
|
||||
|
||||
# Track active message handlers per session for interrupt support
|
||||
# Key: session_key (e.g., chat_id), Value: (event, asyncio.Event for interrupt)
|
||||
self._active_sessions: Dict[str, asyncio.Event] = {}
|
||||
self._pending_messages: Dict[str, MessageEvent] = {}
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -190,12 +195,33 @@ class BasePlatformAdapter(ABC):
|
||||
"""
|
||||
Process an incoming message.
|
||||
|
||||
Calls the registered message handler and sends the response.
|
||||
Keeps typing indicator active throughout processing.
|
||||
This method returns quickly by spawning background tasks.
|
||||
This allows new messages to be processed even while an agent is running,
|
||||
enabling interruption support.
|
||||
"""
|
||||
if not self._message_handler:
|
||||
return
|
||||
|
||||
session_key = event.source.chat_id
|
||||
|
||||
# Check if there's already an active handler for this session
|
||||
if session_key in self._active_sessions:
|
||||
# Store this as a pending message - it will interrupt the running agent
|
||||
print(f"[{self.name}] ⚡ New message while session {session_key} is active - triggering interrupt")
|
||||
self._pending_messages[session_key] = event
|
||||
# Signal the interrupt (the processing task checks this)
|
||||
self._active_sessions[session_key].set()
|
||||
return # Don't process now - will be handled after current task finishes
|
||||
|
||||
# Spawn background task to process this message
|
||||
asyncio.create_task(self._process_message_background(event, session_key))
|
||||
|
||||
async def _process_message_background(self, event: MessageEvent, session_key: str) -> None:
|
||||
"""Background task that actually processes the message."""
|
||||
# Create interrupt event for this session
|
||||
interrupt_event = asyncio.Event()
|
||||
self._active_sessions[session_key] = interrupt_event
|
||||
|
||||
# Start continuous typing indicator (refreshes every 2 seconds)
|
||||
typing_task = asyncio.create_task(self._keep_typing(event.source.chat_id))
|
||||
|
||||
@@ -222,6 +248,23 @@ class BasePlatformAdapter(ABC):
|
||||
)
|
||||
if not fallback_result.success:
|
||||
print(f"[{self.name}] Fallback send also failed: {fallback_result.error}")
|
||||
|
||||
# Check if there's a pending message that was queued during our processing
|
||||
if session_key in self._pending_messages:
|
||||
pending_event = self._pending_messages.pop(session_key)
|
||||
print(f"[{self.name}] 📨 Processing queued message from interrupt")
|
||||
# Clean up current session before processing pending
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
typing_task.cancel()
|
||||
try:
|
||||
await typing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Process pending message in new background task
|
||||
await self._process_message_background(pending_event, session_key)
|
||||
return # Already cleaned up
|
||||
|
||||
except Exception as e:
|
||||
print(f"[{self.name}] Error handling message: {e}")
|
||||
import traceback
|
||||
@@ -233,6 +276,17 @@ class BasePlatformAdapter(ABC):
|
||||
await typing_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
# Clean up session tracking
|
||||
if session_key in self._active_sessions:
|
||||
del self._active_sessions[session_key]
|
||||
|
||||
def has_pending_interrupt(self, session_key: str) -> bool:
|
||||
"""Check if there's a pending interrupt for a session."""
|
||||
return session_key in self._active_sessions and self._active_sessions[session_key].is_set()
|
||||
|
||||
def get_pending_message(self, session_key: str) -> Optional[MessageEvent]:
|
||||
"""Get and clear any pending message for a session."""
|
||||
return self._pending_messages.get(session_key)
|
||||
|
||||
def build_source(
|
||||
self,
|
||||
|
||||
@@ -531,6 +531,27 @@ class GatewayRunner:
|
||||
|
||||
tracking_task = asyncio.create_task(track_agent())
|
||||
|
||||
# Monitor for interrupts from the adapter (new messages arriving)
|
||||
async def monitor_for_interrupt():
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if not adapter:
|
||||
return
|
||||
|
||||
chat_id = source.chat_id
|
||||
while True:
|
||||
await asyncio.sleep(0.2) # Check every 200ms
|
||||
# Check if adapter has a pending interrupt for this session
|
||||
if hasattr(adapter, 'has_pending_interrupt') and adapter.has_pending_interrupt(chat_id):
|
||||
agent = agent_holder[0]
|
||||
if agent:
|
||||
pending_event = adapter.get_pending_message(chat_id)
|
||||
pending_text = pending_event.text if pending_event else None
|
||||
print(f"[gateway] ⚡ Interrupt detected from adapter, signaling agent...")
|
||||
agent.interrupt(pending_text)
|
||||
break
|
||||
|
||||
interrupt_monitor = asyncio.create_task(monitor_for_interrupt())
|
||||
|
||||
try:
|
||||
# Run in thread pool to not block
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -538,42 +559,55 @@ class GatewayRunner:
|
||||
|
||||
# Check if we were interrupted and have a pending message
|
||||
result = result_holder[0]
|
||||
if result and result.get("interrupted") and session_key:
|
||||
pending = self._pending_messages.pop(session_key, None)
|
||||
if pending:
|
||||
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
||||
# Add an indicator to the response
|
||||
if response:
|
||||
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
||||
|
||||
# Send the interrupted response first
|
||||
adapter = self.adapters.get(source.platform)
|
||||
if adapter and response:
|
||||
await adapter.send(chat_id=source.chat_id, content=response)
|
||||
|
||||
# Now process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
return await self._run_agent(
|
||||
message=pending,
|
||||
context_prompt=context_prompt,
|
||||
history=updated_history,
|
||||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key
|
||||
)
|
||||
adapter = self.adapters.get(source.platform)
|
||||
|
||||
# Get pending message from adapter if interrupted
|
||||
pending = None
|
||||
if result and result.get("interrupted") and adapter:
|
||||
pending_event = adapter.get_pending_message(source.chat_id)
|
||||
if pending_event:
|
||||
pending = pending_event.text
|
||||
elif result.get("interrupt_message"):
|
||||
pending = result.get("interrupt_message")
|
||||
|
||||
if pending:
|
||||
print(f"[gateway] 📨 Processing interrupted message: '{pending[:40]}...'")
|
||||
# Add an indicator to the response
|
||||
if response:
|
||||
response = response + "\n\n---\n_[Interrupted - processing your new message]_"
|
||||
|
||||
# Send the interrupted response first
|
||||
if adapter and response:
|
||||
await adapter.send(chat_id=source.chat_id, content=response)
|
||||
|
||||
# Now process the pending message with updated history
|
||||
updated_history = result.get("messages", history)
|
||||
return await self._run_agent(
|
||||
message=pending,
|
||||
context_prompt=context_prompt,
|
||||
history=updated_history,
|
||||
source=source,
|
||||
session_id=session_id,
|
||||
session_key=session_key
|
||||
)
|
||||
finally:
|
||||
# Stop progress sender
|
||||
# Stop progress sender and interrupt monitor
|
||||
if progress_task:
|
||||
progress_task.cancel()
|
||||
interrupt_monitor.cancel()
|
||||
|
||||
# Clean up tracking
|
||||
tracking_task.cancel()
|
||||
if session_key and session_key in self._running_agents:
|
||||
del self._running_agents[session_key]
|
||||
try:
|
||||
await progress_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Wait for cancelled tasks
|
||||
for task in [progress_task, interrupt_monitor, tracking_task]:
|
||||
if task:
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return response
|
||||
|
||||
|
||||
Reference in New Issue
Block a user