173 lines
6.1 KiB
Python
173 lines
6.1 KiB
Python
|
|
"""Gateway streaming consumer — bridges sync agent callbacks to async platform delivery.
|
||
|
|
|
||
|
|
The agent fires stream_delta_callback(text) synchronously from its worker thread.
GatewayStreamConsumer:
  1. Receives deltas via on_delta() (thread-safe, sync)
  2. Hands them to the asyncio task through a thread-safe queue.Queue
  3. The async run() task buffers, rate-limits, and progressively edits
     a single message on the target platform

Design: Uses the edit transport (send an initial message, then edit it in
place, e.g. editMessageText on Telegram).
This is universally supported across Telegram, Discord, and Slack.

Credit: jobless0x (#774, #1312), OutThisLife (#798), clicksingh (#697).
|
||
|
|
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import logging
|
||
|
|
import queue
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Any, Optional
|
||
|
|
|
||
|
|
# Module-level logger used by the stream consumer defined in this file.
logger = logging.getLogger("gateway.stream_consumer")

# Sentinel to signal the stream is complete. It is a unique object that the
# consumer recognises by identity (``item is _DONE``), so no text delta can
# ever be mistaken for it.
_DONE = object()
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class StreamConsumerConfig:
    """Tunable settings for one stream consumer instance.

    All fields have defaults, so ``StreamConsumerConfig()`` is a valid,
    ready-to-use configuration.
    """

    # Seconds that must have elapsed since the previous edit before buffered
    # text is flushed on time alone.
    edit_interval: float = 0.3

    # Number of buffered characters that triggers a flush without waiting
    # for ``edit_interval``.
    buffer_threshold: int = 40

    # Marker appended to the displayed text while the stream is still running;
    # it is omitted from the final version of the message.
    cursor: str = " ▉"
|
||
|
|
|
||
|
|
|
||
|
|
class GatewayStreamConsumer:
    """Async consumer that progressively edits a platform message with streamed tokens.

    Text deltas are pushed from a worker thread through :meth:`on_delta` and
    collected by the :meth:`run` coroutine, which sends one message and then
    keeps editing it as more text arrives. An edit is only issued when the
    visible text would actually change, so an idle stream produces no traffic.

    Usage::

        consumer = GatewayStreamConsumer(adapter, chat_id, config, metadata=metadata)
        # Pass consumer.on_delta as stream_delta_callback to AIAgent
        agent = AIAgent(..., stream_delta_callback=consumer.on_delta)
        # Start the consumer as an asyncio task
        task = asyncio.create_task(consumer.run())
        # ... run agent in thread pool ...
        consumer.finish()  # signal completion
        await task         # wait for final edit
    """

    def __init__(
        self,
        adapter: Any,
        chat_id: str,
        config: Optional["StreamConsumerConfig"] = None,
        metadata: Optional[dict] = None,
    ):
        """Create a consumer bound to one chat.

        Args:
            adapter: Platform adapter. It must provide awaitable
                ``send(chat_id=, content=, metadata=)`` and
                ``edit_message(chat_id=, message_id=, content=)`` whose results
                expose ``success`` (and ``message_id`` for ``send``).
            chat_id: Identifier of the chat that receives the message.
            config: Streaming settings; defaults are used when omitted.
            metadata: Passed through unchanged to ``adapter.send``.
        """
        self.adapter = adapter
        self.chat_id = chat_id
        self.cfg = config if config is not None else StreamConsumerConfig()
        self.metadata = metadata
        self._queue: queue.Queue = queue.Queue()
        self._accumulated = ""
        self._message_id: Optional[str] = None
        self._already_sent = False
        self._last_edit_time = 0.0
        # Text of the most recent successful send/edit; used to skip no-op edits.
        self._last_sent_text: Optional[str] = None
        # Length of ``_accumulated`` at the most recent flush attempt; the
        # buffer threshold is measured against text received since then.
        self._flushed_len = 0

    @property
    def already_sent(self) -> bool:
        """True if at least one message was sent/edited — signals the base
        adapter to skip re-sending the final response."""
        return self._already_sent

    def on_delta(self, text: str) -> None:
        """Thread-safe callback — called from the agent's worker thread.

        Empty or ``None`` deltas are ignored.
        """
        if text:
            self._queue.put(text)

    def finish(self) -> None:
        """Signal that the stream is complete."""
        self._queue.put(_DONE)

    def _drain_queue(self) -> bool:
        """Move every queued delta into the buffer without blocking.

        Returns:
            True if the completion sentinel was reached, False if the queue
            simply ran empty.
        """
        parts = []
        done = False
        while True:
            try:
                item = self._queue.get_nowait()
            except queue.Empty:
                break
            if item is _DONE:
                done = True
                break
            parts.append(item)
        if parts:
            self._accumulated += "".join(parts)
        return done

    async def run(self) -> None:
        """Async task that drains the queue and edits the platform message.

        Runs until :meth:`finish` has been called and the final text (without
        the cursor) has been delivered. Unexpected errors are logged and end
        the task; they are not raised to the caller.
        """
        try:
            while True:
                got_done = self._drain_queue()

                display_text = self._accumulated
                if not got_done:
                    display_text += self.cfg.cursor

                # Only flush when the visible text would change. Comparing the
                # total buffer length against the threshold (as opposed to the
                # text received since the last flush) would re-edit the message
                # on every poll once the buffer grew past the threshold.
                dirty = bool(self._accumulated) and display_text != self._last_sent_text
                unflushed = len(self._accumulated) - self._flushed_len
                elapsed = time.monotonic() - self._last_edit_time
                should_edit = dirty and (
                    got_done
                    or elapsed >= self.cfg.edit_interval
                    or unflushed >= self.cfg.buffer_threshold
                )

                if should_edit:
                    await self._send_or_edit(display_text)
                    self._flushed_len = len(self._accumulated)
                    self._last_edit_time = time.monotonic()

                if got_done:
                    # Final edit without cursor. This is a no-op when the flush
                    # above already delivered the final text, and a single
                    # retry when that flush failed.
                    if self._accumulated and self._message_id:
                        await self._send_or_edit(self._accumulated)
                    return

                await asyncio.sleep(0.05)  # Small yield to not busy-loop

        except asyncio.CancelledError:
            # Best-effort final edit on cancellation (removes the cursor).
            # The cancellation is intentionally not re-raised so the task
            # still completes normally for callers that await it.
            if self._accumulated and self._message_id:
                try:
                    await self._send_or_edit(self._accumulated)
                except Exception as e:
                    logger.debug("Final edit after cancellation failed: %s", e)
        except Exception as e:
            logger.error("Stream consumer error: %s", e)

    async def _send_or_edit(self, text: str) -> None:
        """Send or edit the streaming message.

        Edits the existing message when one has been sent; otherwise, or when
        the edit reports failure, sends a new message and remembers its id.
        Adapter exceptions are logged and suppressed.

        Args:
            text: Full text the message should display.
        """
        # Skip edits that would not change the message: they waste a request,
        # and a platform that rejects unchanged edits (Telegram, for example,
        # answers "message is not modified") would push us into the fallback
        # below and post a duplicate message.
        if self._message_id is not None and text == self._last_sent_text:
            return

        try:
            if self._message_id is not None:
                # Edit existing message
                result = await self.adapter.edit_message(
                    chat_id=self.chat_id,
                    message_id=self._message_id,
                    content=text,
                )
                if result.success:
                    self._already_sent = True
                    self._last_sent_text = text
                    return
                # Edit failed — try sending as new message
                logger.debug("Edit failed, sending new message")

            # First message, or fallback after a failed edit.
            result = await self.adapter.send(
                chat_id=self.chat_id,
                content=text,
                metadata=self.metadata,
            )
            if result.success and result.message_id:
                self._message_id = result.message_id
                self._already_sent = True
                self._last_sent_text = text
        except Exception as e:
            logger.error("Stream send/edit error: %s", e)
|