172 lines
5.2 KiB
Python
172 lines
5.2 KiB
Python
|
|
"""Callback factories for bridging AIAgent events to ACP notifications.
|
||
|
|
|
||
|
|
Each factory returns a callable with the signature that AIAgent expects
for its callbacks. Internally, the callbacks push ACP session updates
to the client via ``conn.session_update()`` using
``asyncio.run_coroutine_threadsafe()`` (since AIAgent runs in a worker
thread while the event loop lives on the main thread).
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import logging
|
||
|
|
from collections import defaultdict, deque
|
||
|
|
from typing import Any, Callable, Deque, Dict
|
||
|
|
|
||
|
|
import acp
|
||
|
|
|
||
|
|
from .tools import (
|
||
|
|
build_tool_complete,
|
||
|
|
build_tool_start,
|
||
|
|
make_tool_call_id,
|
||
|
|
)
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
def _send_update(
    conn: acp.Client,
    session_id: str,
    loop: asyncio.AbstractEventLoop,
    update: Any,
) -> None:
    """Send one ACP session update from a worker thread, best-effort.

    Schedules ``conn.session_update(session_id, update)`` on ``loop`` via
    ``asyncio.run_coroutine_threadsafe()`` and then waits up to 5 seconds
    for it to finish, so the calling thread may block for that long.

    Any failure (creating or scheduling the coroutine, the wait timing out,
    or an exception raised by the coroutine itself) is logged at DEBUG level
    and swallowed; nothing is raised to the caller. On timeout the scheduled
    coroutine is not cancelled, so the update may still complete later.

    Args:
        conn: Object whose ``session_update`` coroutine method is invoked.
        session_id: Session identifier passed through to ``session_update``.
        loop: Event loop on which the coroutine is scheduled.
        update: Update payload passed through to ``session_update``.
    """
    coro = None
    scheduled = False
    try:
        coro = conn.session_update(session_id, update)
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        scheduled = True
        future.result(timeout=5)
    except Exception:
        logger.debug("Failed to send ACP update", exc_info=True)
        if not scheduled and asyncio.iscoroutine(coro):
            # The coroutine object was created but never handed to the loop
            # (for example because the loop is already closed). Close it so
            # it does not emit a "coroutine ... was never awaited"
            # RuntimeWarning when it is garbage-collected.
            coro.close()
|
||
|
|
|
||
|
|
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
# Tool progress callback
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
|
||
|
|
def make_tool_progress_cb(
    conn: acp.Client,
    session_id: str,
    loop: asyncio.AbstractEventLoop,
    tool_call_ids: Dict[str, Deque[str]],
) -> Callable:
    """Build the ``tool_progress_callback`` handed to AIAgent.

    The returned callable is invoked as::

        tool_progress_callback(name: str, preview: str, args: dict)

    Each invocation obtains a fresh ID from ``make_tool_call_id()``, appends
    it to the queue stored under ``name`` in ``tool_call_ids`` (creating the
    queue when missing), and sends the update produced by
    ``build_tool_start``. Keeping one queue per tool name lets repeated or
    parallel calls of the same tool each keep their own ID.
    """

    def _normalize_args(raw: Any) -> dict:
        """Coerce the callback's ``args`` value into a dict."""
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except (json.JSONDecodeError, TypeError):
                # Not parseable as JSON: keep the original text under "raw".
                return {"raw": raw}
        # Anything that is not a dict (None, list, number, ...) becomes {}.
        return raw if isinstance(raw, dict) else {}

    def _pending_ids(name: str) -> Deque[str]:
        """Return the ID queue for ``name``, creating or wrapping as needed."""
        pending = tool_call_ids.get(name)
        if pending is None or isinstance(pending, str):
            # A bare string entry is wrapped into a one-element deque.
            seed = () if pending is None else (pending,)
            pending = tool_call_ids[name] = deque(seed)
        return pending

    def _tool_progress(name: str, preview: str, args: Any = None) -> None:
        call_args = _normalize_args(args)
        call_id = make_tool_call_id()
        _pending_ids(name).append(call_id)
        _send_update(
            conn,
            session_id,
            loop,
            build_tool_start(call_id, name, call_args),
        )

    return _tool_progress
|
||
|
|
|
||
|
|
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
# Thinking callback
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
|
||
|
|
def make_thinking_cb(
    conn: acp.Client,
    session_id: str,
    loop: asyncio.AbstractEventLoop,
) -> Callable:
    """Build the ``thinking_callback`` handed to AIAgent.

    The returned callable takes a single ``text`` argument. Falsy text is
    ignored; otherwise the text is wrapped with
    ``acp.update_agent_thought_text`` and sent through ``_send_update``.
    """

    def _thinking(text: str) -> None:
        if text:
            _send_update(
                conn,
                session_id,
                loop,
                acp.update_agent_thought_text(text),
            )

    return _thinking
|
||
|
|
|
||
|
|
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
# Step callback
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
|
||
|
|
def make_step_cb(
    conn: acp.Client,
    session_id: str,
    loop: asyncio.AbstractEventLoop,
    tool_call_ids: Dict[str, Deque[str]],
) -> Callable:
    """Create a ``step_callback`` for AIAgent.

    Signature expected by AIAgent::

        step_callback(api_call_count: int, prev_tools: list)

    For every entry in ``prev_tools`` the oldest ID queued under that tool's
    name in ``tool_call_ids`` is removed and an update built by
    ``build_tool_complete`` is sent for it. A name's queue is dropped from
    ``tool_call_ids`` once it is empty.

    Entries may be dicts (name under ``"name"`` or ``"function_name"``,
    result under ``"result"`` or ``"output"``) or plain tool-name strings.
    Entries without a usable string name, or with no queued ID, are skipped
    without modifying ``tool_call_ids``. ``prev_tools`` values that are not
    lists are ignored.
    """

    def _describe(tool_info: Any) -> tuple:
        """Return ``(tool_name, result)`` for one ``prev_tools`` entry."""
        if isinstance(tool_info, str):
            return tool_info, None
        if isinstance(tool_info, dict):
            tool_name = tool_info.get("name") or tool_info.get("function_name")
            # Fall back to "output" only when "result" is absent/None, so
            # legitimate falsy results (0, False, "") are not discarded.
            result = tool_info.get("result")
            if result is None:
                result = tool_info.get("output")
            return tool_name, result
        return None, None

    def _step(api_call_count: int, prev_tools: Any = None) -> None:
        if not isinstance(prev_tools, list):
            return

        for tool_info in prev_tools:
            tool_name, result = _describe(tool_info)
            # Skip before touching the mapping so a missing or non-string
            # name can never create a stray key in ``tool_call_ids``.
            if not tool_name or not isinstance(tool_name, str):
                continue

            queue = tool_call_ids.get(tool_name)
            if isinstance(queue, str):
                # A bare string entry is wrapped into a one-element deque.
                queue = deque([queue])
                tool_call_ids[tool_name] = queue
            if not queue:
                continue

            tc_id = queue.popleft()
            update = build_tool_complete(
                tc_id,
                tool_name,
                result=None if result is None else str(result),
            )
            _send_update(conn, session_id, loop, update)
            if not queue:
                tool_call_ids.pop(tool_name, None)

    return _step
|
||
|
|
|
||
|
|
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
# Agent message callback
|
||
|
|
# ------------------------------------------------------------------
|
||
|
|
|
||
|
|
def make_message_cb(
    conn: acp.Client,
    session_id: str,
    loop: asyncio.AbstractEventLoop,
) -> Callable:
    """Build a callback that forwards agent response text as ACP updates.

    The returned callable takes a single ``text`` argument. Falsy text is
    ignored; otherwise the text is wrapped with
    ``acp.update_agent_message_text`` and sent through ``_send_update``.
    """

    def _message(text: str) -> None:
        if text:
            _send_update(
                conn,
                session_id,
                loop,
                acp.update_agent_message_text(text),
            )

    return _message
|