fix(api-server): keep chat-completions SSE alive
This commit is contained in:
@@ -53,6 +53,7 @@ DEFAULT_HOST = "127.0.0.1"
|
||||
DEFAULT_PORT = 8642
|
||||
MAX_STORED_RESPONSES = 100
|
||||
MAX_REQUEST_BYTES = 1_000_000 # 1 MB default limit for POST bodies
|
||||
CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS = 30.0
|
||||
|
||||
|
||||
def check_api_server_requirements() -> bool:
|
||||
@@ -762,7 +763,11 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"""
|
||||
import queue as _q
|
||||
|
||||
sse_headers = {"Content-Type": "text/event-stream", "Cache-Control": "no-cache"}
|
||||
sse_headers = {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"X-Accel-Buffering": "no",
|
||||
}
|
||||
# CORS middleware can't inject headers into StreamResponse after
|
||||
# prepare() flushes them, so resolve CORS headers up front.
|
||||
origin = request.headers.get("Origin", "")
|
||||
@@ -775,6 +780,8 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
await response.prepare(request)
|
||||
|
||||
try:
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Role chunk
|
||||
role_chunk = {
|
||||
"id": completion_id, "object": "chat.completion.chunk",
|
||||
@@ -782,6 +789,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"choices": [{"index": 0, "delta": {"role": "assistant"}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(role_chunk)}\n\n".encode())
|
||||
last_activity = time.monotonic()
|
||||
|
||||
# Helper — route a queue item to the correct SSE event.
|
||||
async def _emit(item):
|
||||
@@ -805,6 +813,7 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
"choices": [{"index": 0, "delta": {"content": item}, "finish_reason": None}],
|
||||
}
|
||||
await response.write(f"data: {json.dumps(content_chunk)}\n\n".encode())
|
||||
return time.monotonic()
|
||||
|
||||
# Stream content chunks as they arrive from the agent
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -819,16 +828,19 @@ class APIServerAdapter(BasePlatformAdapter):
|
||||
delta = stream_q.get_nowait()
|
||||
if delta is None:
|
||||
break
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
except _q.Empty:
|
||||
break
|
||||
break
|
||||
if time.monotonic() - last_activity >= CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS:
|
||||
await response.write(b": keepalive\n\n")
|
||||
last_activity = time.monotonic()
|
||||
continue
|
||||
|
||||
if delta is None: # End of stream sentinel
|
||||
break
|
||||
|
||||
await _emit(delta)
|
||||
last_activity = await _emit(delta)
|
||||
|
||||
# Get usage from completed agent
|
||||
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
@@ -409,11 +409,50 @@ class TestChatCompletionsEndpoint:
|
||||
)
|
||||
assert resp.status == 200
|
||||
assert "text/event-stream" in resp.headers.get("Content-Type", "")
|
||||
assert resp.headers.get("X-Accel-Buffering") == "no"
|
||||
body = await resp.text()
|
||||
assert "data: " in body
|
||||
assert "[DONE]" in body
|
||||
assert "Hello!" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_sends_keepalive_during_quiet_tool_gap(self, adapter):
|
||||
"""Idle SSE streams should send keepalive comments while tools run silently."""
|
||||
import asyncio
|
||||
import gateway.platforms.api_server as api_server_mod
|
||||
|
||||
app = _create_app(adapter)
|
||||
async with TestClient(TestServer(app)) as cli:
|
||||
async def _mock_run_agent(**kwargs):
|
||||
cb = kwargs.get("stream_delta_callback")
|
||||
if cb:
|
||||
cb("Working")
|
||||
await asyncio.sleep(0.65)
|
||||
cb("...done")
|
||||
return (
|
||||
{"final_response": "Working...done", "messages": [], "api_calls": 1},
|
||||
{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(api_server_mod, "CHAT_COMPLETIONS_SSE_KEEPALIVE_SECONDS", 0.01),
|
||||
patch.object(adapter, "_run_agent", side_effect=_mock_run_agent),
|
||||
):
|
||||
resp = await cli.post(
|
||||
"/v1/chat/completions",
|
||||
json={
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": "do the thing"}],
|
||||
"stream": True,
|
||||
},
|
||||
)
|
||||
assert resp.status == 200
|
||||
body = await resp.text()
|
||||
assert ": keepalive" in body
|
||||
assert "Working" in body
|
||||
assert "...done" in body
|
||||
assert "[DONE]" in body
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stream_survives_tool_call_none_sentinel(self, adapter):
|
||||
"""stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream.
|
||||
|
||||
Reference in New Issue
Block a user