diff --git a/gateway/platforms/api_server.py b/gateway/platforms/api_server.py index 19fa5f60d..a27408f4c 100644 --- a/gateway/platforms/api_server.py +++ b/gateway/platforms/api_server.py @@ -380,6 +380,7 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt: Optional[str] = None, session_id: Optional[str] = None, stream_delta_callback=None, + tool_progress_callback=None, ) -> Any: """ Create an AIAgent instance using the gateway's runtime config. @@ -412,6 +413,7 @@ class APIServerAdapter(BasePlatformAdapter): session_id=session_id, platform="api_server", stream_delta_callback=stream_delta_callback, + tool_progress_callback=tool_progress_callback, ) return agent @@ -514,6 +516,15 @@ class APIServerAdapter(BasePlatformAdapter): if delta is not None: _stream_q.put(delta) + def _on_tool_progress(name, preview, args): + """Inject tool progress into the SSE stream for Open WebUI.""" + if name.startswith("_"): + return # Skip internal events (_thinking) + from agent.display import get_tool_emoji + emoji = get_tool_emoji(name) + label = preview or name + _stream_q.put(f"\n`{emoji} {label}`\n") + # Start agent in background. agent_ref is a mutable container # so the SSE writer can interrupt the agent on client disconnect. agent_ref = [None] @@ -523,6 +534,7 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt=system_prompt, session_id=session_id, stream_delta_callback=_on_delta, + tool_progress_callback=_on_tool_progress, agent_ref=agent_ref, )) @@ -1194,6 +1206,7 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt: Optional[str] = None, session_id: Optional[str] = None, stream_delta_callback=None, + tool_progress_callback=None, agent_ref: Optional[list] = None, ) -> tuple: """ @@ -1214,6 +1227,7 @@ class APIServerAdapter(BasePlatformAdapter): ephemeral_system_prompt=ephemeral_system_prompt, session_id=session_id, stream_delta_callback=stream_delta_callback, + tool_progress_callback=tool_progress_callback, ) if agent_ref is not None: agent_ref[0] = agent diff --git a/tests/gateway/test_api_server.py b/tests/gateway/test_api_server.py index 772dd8b1c..b48ac1af7 100644 --- a/tests/gateway/test_api_server.py +++ b/tests/gateway/test_api_server.py @@ -427,6 +427,81 @@ class TestChatCompletionsEndpoint: assert "Thinking" in body assert " about it..." in body + @pytest.mark.asyncio + async def test_stream_includes_tool_progress(self, adapter): + """tool_progress_callback fires → progress appears in the SSE stream.""" + import asyncio + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + tp_cb = kwargs.get("tool_progress_callback") + # Simulate tool progress before streaming content + if tp_cb: + tp_cb("terminal", "ls -la", {"command": "ls -la"}) + if cb: + await asyncio.sleep(0.05) + cb("Here are the files.") + return ( + {"final_response": "Here are the files.", "messages": [], "api_calls": 1}, + {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + ) + + with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "list files"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + assert "[DONE]" in body + # Tool progress message must appear in the stream + assert "ls -la" in body + # Final content must also be present + assert "Here are the files." in body + + @pytest.mark.asyncio + async def test_stream_tool_progress_skips_internal_events(self, adapter): + """Internal events (name starting with _) are not streamed.""" + import asyncio + + app = _create_app(adapter) + async with TestClient(TestServer(app)) as cli: + async def _mock_run_agent(**kwargs): + cb = kwargs.get("stream_delta_callback") + tp_cb = kwargs.get("tool_progress_callback") + if tp_cb: + tp_cb("_thinking", "some internal state", {}) + tp_cb("web_search", "Python docs", {"query": "Python docs"}) + if cb: + await asyncio.sleep(0.05) + cb("Found it.") + return ( + {"final_response": "Found it.", "messages": [], "api_calls": 1}, + {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}, + ) + + with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent): + resp = await cli.post( + "/v1/chat/completions", + json={ + "model": "test", + "messages": [{"role": "user", "content": "search"}], + "stream": True, + }, + ) + assert resp.status == 200 + body = await resp.text() + # Internal _thinking event should NOT appear + assert "some internal state" not in body + # Real tool progress should appear + assert "Python docs" in body + @pytest.mark.asyncio async def test_no_user_message_returns_400(self, adapter): app = _create_app(adapter)