Files
hermes-agent/tests/gateway/test_sse_agent_cancel.py
Teknium f57ebf52e9 fix(api-server): cancel orphaned agent + true interrupt on SSE disconnect (salvage #3399) (#3427)
Salvage of #3399 by @binhnt92 with true agent interruption added on top.

When a streaming /v1/chat/completions client disconnects mid-stream, the agent is now interrupted via agent.interrupt() so it stops making LLM API calls, and the asyncio task wrapper is cancelled.

Closes #3399.
2026-03-27 11:33:19 -07:00

281 lines
9.7 KiB
Python

"""Tests for SSE client disconnect → agent task cancellation.
When a streaming /v1/chat/completions client disconnects mid-stream
(network drop, browser tab close), the agent is interrupted via
agent.interrupt() so it stops making LLM API calls, and the asyncio
task wrapper is cancelled.
"""
import asyncio
import json
import queue
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_adapter():
    """Return an APIServerAdapter built from a minimal, enabled test config.

    The gateway imports live inside the function, so they only execute when a
    test actually calls this helper.
    """
    from gateway.platforms.api_server import APIServerAdapter
    from gateway.config import PlatformConfig

    return APIServerAdapter(PlatformConfig(enabled=True, token="test-key"))
def _make_request():
    """Return a stand-in for an aiohttp request.

    The stand-in is a plain ``MagicMock`` whose ``headers`` attribute is an
    empty dict, i.e. a request that carries no headers.
    """
    fake_request = MagicMock()
    fake_request.headers = dict()
    return fake_request
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestSSEAgentCancelOnDisconnect:
    """gateway/platforms/api_server.py — _write_sse_chat_completion()

    Each test drives ``_write_sse_chat_completion`` against a mocked
    ``web.StreamResponse`` whose ``write`` can be made to raise, then checks
    what happened to the agent task (and, when passed, the agent held in
    ``agent_ref``) once the client connection went away.
    """

    # -- helpers ------------------------------------------------------------

    @staticmethod
    def _mock_stream_response(exc_type=None, message="", fail_from=1):
        """Build an ``AsyncMock`` standing in for ``web.StreamResponse``.

        Args:
            exc_type: Exception class raised by ``write``; ``None`` means
                every write succeeds.
            message: Message passed to ``exc_type``; a fresh exception
                instance is raised on each failing write.
            fail_from: 1-based index of the first ``write`` call that raises.
                Earlier writes succeed, so some output can be sent before
                the disconnect is observed.

        Returns:
            The mock response. ``prepare`` always succeeds.
        """
        from aiohttp import web

        response = AsyncMock(spec=web.StreamResponse)
        response.prepare = AsyncMock()
        if exc_type is None:
            response.write = AsyncMock()
            return response

        writes = 0

        async def write(data):
            nonlocal writes
            writes += 1
            if writes >= fail_from:
                raise exc_type(message)

        response.write = AsyncMock(side_effect=write)
        return response

    @staticmethod
    def _patch_stream_response(response):
        """Patch api_server so that constructing ``web.StreamResponse``
        returns *response*."""
        return patch("gateway.platforms.api_server.web.StreamResponse",
                     return_value=response)

    # -- tests --------------------------------------------------------------

    def test_agent_task_cancelled_on_client_disconnect(self):
        """When response.write raises ConnectionResetError (client dropped),
        the agent task must be cancelled."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("hello ")  # Some data already queued

        async def run():
            # Created inside the running loop: before Python 3.10 an
            # asyncio.Event binds to the loop current at construction time,
            # which is not the loop asyncio.run() creates.
            agent_done = asyncio.Event()

            async def fake_agent():
                # Blocks until released (simulates a long LLM call).
                await agent_done.wait()
                return {"final_response": "done"}, {
                    "input_tokens": 10, "output_tokens": 5, "total_tokens": 15,
                }

            agent_task = asyncio.create_task(fake_agent())
            # The first write succeeds and the second raises: the client
            # dropped mid-stream.
            response = self._mock_stream_response(
                ConnectionResetError, "client disconnected", fail_from=2)
            try:
                with self._patch_stream_response(response):
                    await adapter._write_sse_chat_completion(
                        _make_request(), "cmpl-123", "gpt-4", 1234567890,
                        stream_q, agent_task,
                    )
                # The critical assertion: agent_task must be cancelled
                assert agent_task.cancelled() or agent_task.done()
            finally:
                # Release the fake agent even if the assertion failed.
                agent_done.set()

        asyncio.run(run())

    def test_agent_task_not_cancelled_on_normal_completion(self):
        """On normal stream completion, agent task should NOT be cancelled."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("hello")
        stream_q.put(None)  # End-of-stream sentinel

        async def fake_agent():
            return {"final_response": "done"}, {
                "input_tokens": 10, "output_tokens": 5, "total_tokens": 15,
            }

        async def run():
            agent_task = asyncio.create_task(fake_agent())
            await asyncio.sleep(0)  # Let agent complete
            response = self._mock_stream_response()  # every write succeeds
            with self._patch_stream_response(response):
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-456", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )
            # Agent should have completed normally, not been cancelled
            assert agent_task.done()
            assert not agent_task.cancelled()

        asyncio.run(run())

    def test_broken_pipe_also_cancels_agent(self):
        """BrokenPipeError (another disconnect variant) also cancels the task."""
        adapter = _make_adapter()
        stream_q = queue.Queue()

        async def fake_agent():
            await asyncio.sleep(999)  # Never completes
            return {}, {}

        async def run():
            agent_task = asyncio.create_task(fake_agent())
            # Every write raises, starting with the first one.
            response = self._mock_stream_response(BrokenPipeError, "pipe broken")
            with self._patch_stream_response(response):
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-789", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )
            assert agent_task.cancelled() or agent_task.done()

        asyncio.run(run())

    def test_already_done_task_not_cancelled_on_disconnect(self):
        """If agent already finished before disconnect, don't try to cancel."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("data")

        async def fake_agent():
            return {"final_response": "done"}, {}

        async def run():
            agent_task = asyncio.create_task(fake_agent())
            await asyncio.sleep(0)  # Let agent complete
            response = self._mock_stream_response(
                ConnectionResetError, "late disconnect", fail_from=2)
            with self._patch_stream_response(response):
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-done", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )
            # Task was already done — should not be cancelled
            assert agent_task.done()
            assert not agent_task.cancelled()

        asyncio.run(run())

    def test_agent_interrupt_called_on_disconnect(self):
        """When the client disconnects, agent.interrupt() must be called
        so the agent thread stops making LLM API calls."""
        adapter = _make_adapter()
        stream_q = queue.Queue()
        stream_q.put("hello ")
        # MagicMock auto-creates ``interrupt`` as a call-recording MagicMock.
        mock_agent = MagicMock()

        async def run():
            # Created inside the running loop (pre-3.10 asyncio.Event binds
            # to the loop current at construction time).
            agent_done = asyncio.Event()

            async def fake_agent():
                await agent_done.wait()
                return {"final_response": "done"}, {}

            agent_task = asyncio.create_task(fake_agent())
            agent_ref = [mock_agent]
            response = self._mock_stream_response(
                ConnectionResetError, "client disconnected", fail_from=2)
            try:
                with self._patch_stream_response(response):
                    await adapter._write_sse_chat_completion(
                        _make_request(), "cmpl-int", "gpt-4", 1234567890,
                        stream_q, agent_task, agent_ref,
                    )
                # agent.interrupt() must have been called
                mock_agent.interrupt.assert_called_once_with(
                    "SSE client disconnected")
            finally:
                # Release the fake agent even if the assertion failed.
                agent_done.set()

        asyncio.run(run())

    def test_agent_ref_none_still_cancels_task(self):
        """When agent_ref is not provided (None), the task is still cancelled
        on disconnect — just without the interrupt() call."""
        adapter = _make_adapter()
        stream_q = queue.Queue()

        async def fake_agent():
            await asyncio.sleep(999)
            return {}, {}

        async def run():
            agent_task = asyncio.create_task(fake_agent())
            response = self._mock_stream_response(BrokenPipeError, "gone")
            with self._patch_stream_response(response):
                # No agent_ref passed — should still handle disconnect cleanly
                await adapter._write_sse_chat_completion(
                    _make_request(), "cmpl-noref", "gpt-4", 1234567890,
                    stream_q, agent_task,
                )
            assert agent_task.cancelled() or agent_task.done()

        asyncio.run(run())