* fix(run_agent): ensure _fire_first_delta() is called for tool generation events Added calls to _fire_first_delta() in the AIAgent class to improve the handling of tool generation events, ensuring timely notifications during the processing of function calls and tool usage. * fix(run_agent): improve timeout handling for chat completions Enhanced the timeout configuration for chat completions in the AIAgent class by introducing customizable connection, read, and write timeouts using environment variables. This ensures more robust handling of API requests during streaming operations. * fix(run_agent): reduce default stream read timeout for chat completions Updated the default stream read timeout from 120 seconds to 60 seconds in the AIAgent class, enhancing the timeout configuration for chat completions. This change aims to improve responsiveness during streaming operations. * fix(run_agent): enhance streaming error handling and retry logic Improved the error handling and retry mechanism for streaming requests in the AIAgent class. Introduced a configurable maximum number of stream retries and refined the handling of transient network errors, allowing for retries with fresh connections. Non-transient errors now trigger a fallback to non-streaming only when appropriate, ensuring better resilience during API interactions. * fix(api_server): streaming breaks when agent makes tool calls The agent fires stream_delta_callback(None) to signal the CLI display to close its response box before tool execution begins. The API server's _on_delta callback was forwarding this None directly into the SSE queue, where the SSE writer treats it as end-of-stream and terminates the HTTP response prematurely. After tool calls complete, the agent streams the final answer through the same callback, but the SSE response was already closed — so Open WebUI (and similar frontends) never received the actual answer. Fix: filter out None in _on_delta so the SSE stream stays open. 
The SSE loop already detects completion via agent_task.done(), which handles stream termination correctly without needing the None sentinel. Reported by Rohit Paul on X.
1440 lines
59 KiB
Python
1440 lines
59 KiB
Python
"""
|
|
Tests for the OpenAI-compatible API server gateway adapter.
|
|
|
|
Tests cover:
|
|
- Chat Completions endpoint (request parsing, response format)
|
|
- Responses API endpoint (request parsing, response format)
|
|
- previous_response_id chaining (store/retrieve)
|
|
- Auth (valid key, invalid key, no key configured)
|
|
- /v1/models endpoint
|
|
- /health endpoint
|
|
- System prompt extraction
|
|
- Error handling (invalid JSON, missing fields)
|
|
"""
|
|
|
|
import json
|
|
import time
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from aiohttp import web
|
|
from aiohttp.test_utils import AioHTTPTestCase, TestClient, TestServer
|
|
|
|
from gateway.config import GatewayConfig, Platform, PlatformConfig
|
|
from gateway.platforms.api_server import (
|
|
APIServerAdapter,
|
|
ResponseStore,
|
|
_CORS_HEADERS,
|
|
check_api_server_requirements,
|
|
cors_middleware,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# check_api_server_requirements
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCheckRequirements:
    """Exercise check_api_server_requirements() with and without the
    AIOHTTP_AVAILABLE flag patched to False."""

    def test_returns_true_when_aiohttp_available(self):
        """Unpatched, the requirement check reports True."""
        outcome = check_api_server_requirements()
        assert outcome is True

    @patch("gateway.platforms.api_server.AIOHTTP_AVAILABLE", False)
    def test_returns_false_without_aiohttp(self):
        """With AIOHTTP_AVAILABLE patched to False, the check reports False."""
        outcome = check_api_server_requirements()
        assert outcome is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ResponseStore
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestResponseStore:
    """Cover ResponseStore put/get/delete and its size-bounded eviction order."""

    def test_put_and_get(self):
        """A stored payload comes back equal to what was put."""
        cache = ResponseStore(max_size=10)
        cache.put("resp_1", {"output": "hello"})
        fetched = cache.get("resp_1")
        assert fetched == {"output": "hello"}

    def test_get_missing_returns_none(self):
        """Looking up an unknown id yields None."""
        cache = ResponseStore(max_size=10)
        fetched = cache.get("resp_missing")
        assert fetched is None

    def test_lru_eviction(self):
        """Exceeding max_size drops the oldest entry."""
        cache = ResponseStore(max_size=3)
        for key, text in (("resp_1", "one"), ("resp_2", "two"), ("resp_3", "three")):
            cache.put(key, {"output": text})

        # The store is full, so a fourth entry pushes out resp_1.
        cache.put("resp_4", {"output": "four"})

        assert cache.get("resp_1") is None
        assert cache.get("resp_2") is not None
        assert len(cache) == 3

    def test_access_refreshes_lru(self):
        """Reading an entry protects it from the next eviction."""
        cache = ResponseStore(max_size=3)
        for key, text in (("resp_1", "one"), ("resp_2", "two"), ("resp_3", "three")):
            cache.put(key, {"output": text})

        # Touch resp_1 so that resp_2 becomes the eviction candidate.
        cache.get("resp_1")
        cache.put("resp_4", {"output": "four"})

        assert cache.get("resp_2") is None
        assert cache.get("resp_1") is not None

    def test_update_existing_key(self):
        """Putting under an existing id overwrites without growing the store."""
        cache = ResponseStore(max_size=10)
        for version in ("v1", "v2"):
            cache.put("resp_1", {"output": version})
        assert cache.get("resp_1") == {"output": "v2"}
        assert len(cache) == 1

    def test_delete_existing(self):
        """Deleting a present id returns True and removes the entry."""
        cache = ResponseStore(max_size=10)
        cache.put("resp_1", {"output": "hello"})
        removed = cache.delete("resp_1")
        assert removed is True
        assert cache.get("resp_1") is None
        assert len(cache) == 0

    def test_delete_missing(self):
        """Deleting an unknown id returns False."""
        cache = ResponseStore(max_size=10)
        removed = cache.delete("resp_missing")
        assert removed is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Adapter initialization
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAdapterInit:
    """Check how APIServerAdapter derives host, port, key and CORS origins
    from PlatformConfig.extra and from API_SERVER_* environment variables."""

    def test_default_config(self, monkeypatch):
        """With no extra settings and no env overrides, the defaults apply."""
        # The adapter also reads these variables (see test_config_from_env), so
        # clear them; otherwise the default assertions depend on whatever the
        # developer's or CI shell happens to export.
        for name in (
            "API_SERVER_HOST",
            "API_SERVER_PORT",
            "API_SERVER_KEY",
            "API_SERVER_CORS_ORIGINS",
        ):
            monkeypatch.delenv(name, raising=False)

        config = PlatformConfig(enabled=True)
        adapter = APIServerAdapter(config)
        assert adapter._host == "127.0.0.1"
        assert adapter._port == 8642
        assert adapter._api_key == ""
        assert adapter.platform == Platform.API_SERVER

    def test_custom_config_from_extra(self):
        """Values supplied through PlatformConfig.extra land on the adapter."""
        config = PlatformConfig(
            enabled=True,
            extra={
                "host": "0.0.0.0",
                "port": 9999,
                "key": "sk-test",
                "cors_origins": ["http://localhost:3000"],
            },
        )
        adapter = APIServerAdapter(config)
        assert adapter._host == "0.0.0.0"
        assert adapter._port == 9999
        assert adapter._api_key == "sk-test"
        # The list given in extra is exposed as a tuple.
        assert adapter._cors_origins == ("http://localhost:3000",)

    def test_config_from_env(self, monkeypatch):
        """API_SERVER_* variables configure the adapter when extra is empty."""
        monkeypatch.setenv("API_SERVER_HOST", "10.0.0.1")
        monkeypatch.setenv("API_SERVER_PORT", "7777")
        monkeypatch.setenv("API_SERVER_KEY", "sk-env")
        monkeypatch.setenv("API_SERVER_CORS_ORIGINS", "http://localhost:3000, http://127.0.0.1:3000")
        config = PlatformConfig(enabled=True)
        adapter = APIServerAdapter(config)
        assert adapter._host == "10.0.0.1"
        # The port arrives as a string and is expected as an int.
        assert adapter._port == 7777
        assert adapter._api_key == "sk-env"
        # The comma-separated value is split and surrounding whitespace dropped.
        assert adapter._cors_origins == (
            "http://localhost:3000",
            "http://127.0.0.1:3000",
        )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth checking
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestAuth:
    """Drive APIServerAdapter._check_auth with mocked requests.

    A None result means the request is allowed; otherwise the returned
    response object carries the rejection status.
    """

    @staticmethod
    def _request_with(headers):
        """Return a MagicMock request whose ``headers`` is the given dict."""
        request = MagicMock()
        request.headers = headers
        return request

    def test_no_key_configured_allows_all(self):
        """Without a configured key, a request with no headers is accepted."""
        server = APIServerAdapter(PlatformConfig(enabled=True))
        verdict = server._check_auth(self._request_with({}))
        assert verdict is None

    def test_valid_key_passes(self):
        """A Bearer token equal to the configured key is accepted."""
        server = APIServerAdapter(PlatformConfig(enabled=True, extra={"key": "sk-test123"}))
        request = self._request_with({"Authorization": "Bearer sk-test123"})
        assert server._check_auth(request) is None

    def test_invalid_key_returns_401(self):
        """A Bearer token that differs from the configured key is rejected."""
        server = APIServerAdapter(PlatformConfig(enabled=True, extra={"key": "sk-test123"}))
        request = self._request_with({"Authorization": "Bearer wrong-key"})
        verdict = server._check_auth(request)
        assert verdict is not None
        assert verdict.status == 401

    def test_missing_auth_header_returns_401(self):
        """With a key configured, omitting Authorization is rejected."""
        server = APIServerAdapter(PlatformConfig(enabled=True, extra={"key": "sk-test123"}))
        verdict = server._check_auth(self._request_with({}))
        assert verdict is not None
        assert verdict.status == 401

    def test_malformed_auth_header_returns_401(self):
        """A non-Bearer Authorization scheme is rejected."""
        server = APIServerAdapter(PlatformConfig(enabled=True, extra={"key": "sk-test123"}))
        request = self._request_with({"Authorization": "Basic dXNlcjpwYXNz"})
        verdict = server._check_auth(request)
        assert verdict is not None
        assert verdict.status == 401
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers for HTTP tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_adapter(api_key: str = "", cors_origins=None) -> APIServerAdapter:
    """Build an APIServerAdapter from an enabled PlatformConfig.

    ``api_key`` is placed under ``extra["key"]`` only when truthy, and
    ``cors_origins`` under ``extra["cors_origins"]`` only when it is not None.
    """
    options = {"key": api_key} if api_key else {}
    if cors_origins is not None:
        options["cors_origins"] = cors_origins
    return APIServerAdapter(PlatformConfig(enabled=True, extra=options))
|
|
|
|
|
|
def _create_app(adapter: APIServerAdapter) -> web.Application:
    """Assemble an aiohttp application wired to the adapter's handlers.

    Only the application object is built; no server is started here.
    """
    application = web.Application(middlewares=[cors_middleware])
    application["api_server_adapter"] = adapter

    router = application.router
    # (registration method, path, handler), registered in this order.
    route_table = (
        (router.add_get, "/health", adapter._handle_health),
        (router.add_get, "/v1/models", adapter._handle_models),
        (router.add_post, "/v1/chat/completions", adapter._handle_chat_completions),
        (router.add_post, "/v1/responses", adapter._handle_responses),
        (router.add_get, "/v1/responses/{response_id}", adapter._handle_get_response),
        (router.add_delete, "/v1/responses/{response_id}", adapter._handle_delete_response),
    )
    for register, path, handler in route_table:
        register(path, handler)
    return application
|
|
|
|
|
|
@pytest.fixture
def adapter():
    """Provide an adapter built by _make_adapter() with its default arguments."""
    default_adapter = _make_adapter()
    return default_adapter
|
|
|
|
|
|
@pytest.fixture
def auth_adapter():
    """Provide an adapter built by _make_adapter() with api_key "sk-secret"."""
    keyed_adapter = _make_adapter(api_key="sk-secret")
    return keyed_adapter
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# /health endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHealthEndpoint:
    """HTTP-level check of GET /health."""

    @pytest.mark.asyncio
    async def test_health_returns_ok(self, adapter):
        """GET /health answers 200 with the expected status and platform fields."""
        async with TestClient(TestServer(_create_app(adapter))) as client:
            response = await client.get("/health")
            assert response.status == 200
            payload = await response.json()
            assert payload["status"] == "ok"
            assert payload["platform"] == "hermes-agent"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# /v1/models endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestModelsEndpoint:
    """HTTP-level checks of GET /v1/models, with and without an API key."""

    @pytest.mark.asyncio
    async def test_models_returns_hermes_agent(self, adapter):
        """The listing contains exactly one model, "hermes-agent"."""
        async with TestClient(TestServer(_create_app(adapter))) as client:
            response = await client.get("/v1/models")
            assert response.status == 200
            payload = await response.json()
            assert payload["object"] == "list"
            entries = payload["data"]
            assert len(entries) == 1
            only_model = entries[0]
            assert only_model["id"] == "hermes-agent"
            assert only_model["owned_by"] == "hermes"

    @pytest.mark.asyncio
    async def test_models_requires_auth(self, auth_adapter):
        """With a key configured, an unauthenticated request gets 401."""
        async with TestClient(TestServer(_create_app(auth_adapter))) as client:
            response = await client.get("/v1/models")
            assert response.status == 401

    @pytest.mark.asyncio
    async def test_models_with_valid_auth(self, auth_adapter):
        """With a key configured, the matching Bearer token gets 200."""
        bearer = {"Authorization": "Bearer sk-secret"}
        async with TestClient(TestServer(_create_app(auth_adapter))) as client:
            response = await client.get("/v1/models", headers=bearer)
            assert response.status == 200
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# /v1/chat/completions endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestChatCompletionsEndpoint:
    """HTTP-level tests for POST /v1/chat/completions.

    The adapter's ``_run_agent`` is patched in every test that reaches the
    agent, so these tests cover request validation, argument mapping,
    streaming (SSE) output and error translation only.
    """

    @pytest.mark.asyncio
    async def test_invalid_json_returns_400(self, adapter):
        """A body that is not JSON yields 400 with an "Invalid JSON" message."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/chat/completions",
                data="not json",
                headers={"Content-Type": "application/json"},
            )
            assert resp.status == 400
            data = await resp.json()
            assert "Invalid JSON" in data["error"]["message"]

    @pytest.mark.asyncio
    async def test_missing_messages_returns_400(self, adapter):
        """Omitting ``messages`` yields 400 and the error names the field."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/chat/completions", json={"model": "test"})
            assert resp.status == 400
            data = await resp.json()
            assert "messages" in data["error"]["message"]

    @pytest.mark.asyncio
    async def test_empty_messages_returns_400(self, adapter):
        """An empty ``messages`` list yields 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/chat/completions", json={"model": "test", "messages": []})
            assert resp.status == 400

    @pytest.mark.asyncio
    async def test_stream_true_returns_sse(self, adapter):
        """stream=true returns SSE format with the full response."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            async def _mock_run_agent(**kwargs):
                # Simulate streaming: invoke stream_delta_callback with tokens
                cb = kwargs.get("stream_delta_callback")
                if cb:
                    cb("Hello!")
                    # Trailing None sentinel; the stream must still finish
                    # with [DONE] once the agent call returns.
                    cb(None)
                return (
                    {"final_response": "Hello!", "messages": [], "api_calls": 1},
                    {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
                )

            with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "test",
                        "messages": [{"role": "user", "content": "hi"}],
                        "stream": True,
                    },
                )
                assert resp.status == 200
                assert "text/event-stream" in resp.headers.get("Content-Type", "")
                body = await resp.text()
                assert "data: " in body
                assert "[DONE]" in body
                assert "Hello!" in body

    @pytest.mark.asyncio
    async def test_stream_survives_tool_call_none_sentinel(self, adapter):
        """stream_delta_callback(None) mid-stream (tool calls) must NOT kill the SSE stream.

        The agent fires stream_delta_callback(None) to tell the CLI display to
        close its response box before executing tool calls. The API server's
        _on_delta must filter this out so the SSE response stays open and the
        final answer (streamed after tool execution) reaches the client.
        """
        import asyncio

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            async def _mock_run_agent(**kwargs):
                cb = kwargs.get("stream_delta_callback")
                if cb:
                    # Simulate: agent streams partial text, then fires None
                    # (tool call box-close signal), then streams the final answer
                    cb("Thinking")
                    cb(None)  # mid-stream None from tool calls
                    await asyncio.sleep(0.05)  # simulate tool execution delay
                    cb(" about it...")
                    cb(None)  # another None (possible second tool round)
                    await asyncio.sleep(0.05)
                    cb(" The answer is 42.")
                return (
                    {"final_response": "Thinking about it... The answer is 42.", "messages": [], "api_calls": 3},
                    {"input_tokens": 20, "output_tokens": 15, "total_tokens": 35},
                )

            with patch.object(adapter, "_run_agent", side_effect=_mock_run_agent):
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "test",
                        "messages": [{"role": "user", "content": "What is the answer?"}],
                        "stream": True,
                    },
                )
                assert resp.status == 200
                body = await resp.text()
                assert "[DONE]" in body
                # The final answer text must appear in the SSE stream
                assert "The answer is 42." in body
                # All partial text must be present too
                assert "Thinking" in body
                assert " about it..." in body

    @pytest.mark.asyncio
    async def test_no_user_message_returns_400(self, adapter):
        """A conversation containing only a system message yields 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/chat/completions",
                json={
                    "model": "test",
                    "messages": [{"role": "system", "content": "You are helpful."}],
                },
            )
            assert resp.status == 400

    @pytest.mark.asyncio
    async def test_successful_completion(self, adapter):
        """Test a successful chat completion with mocked agent."""
        mock_result = {
            "final_response": "Hello! How can I help you today?",
            "messages": [],
            "api_calls": 1,
        }

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                # _run_agent is mocked to return a (result, usage) pair.
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "hermes-agent",
                        "messages": [{"role": "user", "content": "Hello"}],
                    },
                )

            assert resp.status == 200
            data = await resp.json()
            assert data["object"] == "chat.completion"
            assert data["id"].startswith("chatcmpl-")
            assert data["model"] == "hermes-agent"
            assert len(data["choices"]) == 1
            assert data["choices"][0]["message"]["role"] == "assistant"
            assert data["choices"][0]["message"]["content"] == "Hello! How can I help you today?"
            assert data["choices"][0]["finish_reason"] == "stop"
            assert "usage" in data

    @pytest.mark.asyncio
    async def test_system_prompt_extracted(self, adapter):
        """System messages from the client are passed as ephemeral_system_prompt."""
        mock_result = {
            "final_response": "I am a pirate! Arrr!",
            "messages": [],
            "api_calls": 1,
        }

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "hermes-agent",
                        "messages": [
                            {"role": "system", "content": "You are a pirate."},
                            {"role": "user", "content": "Hello"},
                        ],
                    },
                )

            assert resp.status == 200
            # Check that _run_agent was called with the system prompt
            call_kwargs = mock_run.call_args
            assert call_kwargs.kwargs.get("ephemeral_system_prompt") == "You are a pirate."
            assert call_kwargs.kwargs.get("user_message") == "Hello"

    @pytest.mark.asyncio
    async def test_conversation_history_passed(self, adapter):
        """Previous user/assistant messages become conversation_history."""
        mock_result = {"final_response": "3", "messages": [], "api_calls": 1}

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "hermes-agent",
                        "messages": [
                            {"role": "user", "content": "1+1=?"},
                            {"role": "assistant", "content": "2"},
                            {"role": "user", "content": "Now add 1 more"},
                        ],
                    },
                )

            assert resp.status == 200
            call_kwargs = mock_run.call_args.kwargs
            # The last user turn is the message; earlier turns are history.
            assert call_kwargs["user_message"] == "Now add 1 more"
            assert len(call_kwargs["conversation_history"]) == 2
            assert call_kwargs["conversation_history"][0] == {"role": "user", "content": "1+1=?"}
            assert call_kwargs["conversation_history"][1] == {"role": "assistant", "content": "2"}

    @pytest.mark.asyncio
    async def test_agent_error_returns_500(self, adapter):
        """Agent exception returns 500."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.side_effect = RuntimeError("Provider failed")
                resp = await cli.post(
                    "/v1/chat/completions",
                    json={
                        "model": "hermes-agent",
                        "messages": [{"role": "user", "content": "Hello"}],
                    },
                )

            assert resp.status == 500
            data = await resp.json()
            # The exception text is surfaced in the error payload.
            assert "Provider failed" in data["error"]["message"]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# /v1/responses endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestResponsesEndpoint:
    """HTTP-level tests for POST /v1/responses.

    ``_run_agent`` is patched wherever the request reaches the agent, so the
    tests cover input validation, argument mapping, response shape, and the
    store-backed ``previous_response_id`` chaining.
    """

    @pytest.mark.asyncio
    async def test_missing_input_returns_400(self, adapter):
        """Omitting ``input`` yields 400 and the error names the field."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/responses", json={"model": "test"})
            assert resp.status == 400
            data = await resp.json()
            assert "input" in data["error"]["message"]

    @pytest.mark.asyncio
    async def test_invalid_json_returns_400(self, adapter):
        """A body that is not JSON yields 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/responses",
                data="not json",
                headers={"Content-Type": "application/json"},
            )
            assert resp.status == 400

    @pytest.mark.asyncio
    async def test_successful_response_with_string_input(self, adapter):
        """String input is wrapped in a user message."""
        mock_result = {
            "final_response": "Paris is the capital of France.",
            "messages": [],
            "api_calls": 1,
        }

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                # _run_agent is mocked to return a (result, usage) pair.
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": "What is the capital of France?",
                    },
                )

            assert resp.status == 200
            data = await resp.json()
            assert data["object"] == "response"
            assert data["id"].startswith("resp_")
            assert data["status"] == "completed"
            assert len(data["output"]) == 1
            assert data["output"][0]["type"] == "message"
            assert data["output"][0]["content"][0]["type"] == "output_text"
            assert data["output"][0]["content"][0]["text"] == "Paris is the capital of France."

    @pytest.mark.asyncio
    async def test_successful_response_with_array_input(self, adapter):
        """Array input with role/content objects."""
        mock_result = {"final_response": "Done", "messages": [], "api_calls": 1}

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": [
                            {"role": "user", "content": "Hello"},
                            {"role": "user", "content": "What is 2+2?"},
                        ],
                    },
                )

            assert resp.status == 200
            call_kwargs = mock_run.call_args.kwargs
            # Last message is user_message, rest are history
            assert call_kwargs["user_message"] == "What is 2+2?"
            assert len(call_kwargs["conversation_history"]) == 1

    @pytest.mark.asyncio
    async def test_instructions_as_ephemeral_prompt(self, adapter):
        """The instructions field maps to ephemeral_system_prompt."""
        mock_result = {"final_response": "Ahoy!", "messages": [], "api_calls": 1}

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": "Hello",
                        "instructions": "Talk like a pirate.",
                    },
                )

            assert resp.status == 200
            call_kwargs = mock_run.call_args.kwargs
            assert call_kwargs["ephemeral_system_prompt"] == "Talk like a pirate."

    @pytest.mark.asyncio
    async def test_previous_response_id_chaining(self, adapter):
        """Test that responses can be chained via previous_response_id."""
        mock_result_1 = {
            "final_response": "2",
            "messages": [{"role": "assistant", "content": "2"}],
            "api_calls": 1,
        }

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            # First request
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result_1, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp1 = await cli.post(
                    "/v1/responses",
                    json={"model": "hermes-agent", "input": "What is 1+1?"},
                )

            assert resp1.status == 200
            data1 = await resp1.json()
            response_id = data1["id"]

            # Second request chaining from the first
            mock_result_2 = {
                "final_response": "3",
                "messages": [{"role": "assistant", "content": "3"}],
                "api_calls": 1,
            }

            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result_2, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp2 = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": "Now add 1 more",
                        "previous_response_id": response_id,
                    },
                )

            assert resp2.status == 200
            # The conversation_history should contain the full history from the first response
            call_kwargs = mock_run.call_args.kwargs
            assert len(call_kwargs["conversation_history"]) > 0
            assert call_kwargs["user_message"] == "Now add 1 more"

    @pytest.mark.asyncio
    async def test_invalid_previous_response_id_returns_404(self, adapter):
        """Chaining from an id that was never stored yields 404."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/responses",
                json={
                    "model": "hermes-agent",
                    "input": "follow up",
                    "previous_response_id": "resp_nonexistent",
                },
            )
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_store_false_does_not_store(self, adapter):
        """When store=false, the response is NOT stored."""
        mock_result = {"final_response": "OK", "messages": [], "api_calls": 1}

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": "Hello",
                        "store": False,
                    },
                )

            assert resp.status == 200
            data = await resp.json()
            # The response has an ID but it shouldn't be retrievable
            assert adapter._response_store.get(data["id"]) is None

    @pytest.mark.asyncio
    async def test_instructions_inherited_from_previous(self, adapter):
        """If no instructions provided, carry forward from previous response."""
        mock_result = {"final_response": "Ahoy!", "messages": [], "api_calls": 1}

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            # First request with instructions
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp1 = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": "Hello",
                        "instructions": "Be a pirate",
                    },
                )

            # Fail here with a clear status mismatch rather than with a
            # KeyError on "id" below if the first request was rejected.
            assert resp1.status == 200
            data1 = await resp1.json()
            resp_id = data1["id"]

            # Second request without instructions
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp2 = await cli.post(
                    "/v1/responses",
                    json={
                        "model": "hermes-agent",
                        "input": "Tell me more",
                        "previous_response_id": resp_id,
                    },
                )

            assert resp2.status == 200
            call_kwargs = mock_run.call_args.kwargs
            assert call_kwargs["ephemeral_system_prompt"] == "Be a pirate"

    @pytest.mark.asyncio
    async def test_agent_error_returns_500(self, adapter):
        """An exception raised by the agent yields 500."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.side_effect = RuntimeError("Boom")
                resp = await cli.post(
                    "/v1/responses",
                    json={"model": "hermes-agent", "input": "Hello"},
                )

            assert resp.status == 500

    @pytest.mark.asyncio
    async def test_invalid_input_type_returns_400(self, adapter):
        """An ``input`` that is neither a string nor a list (here an int) yields 400."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post(
                "/v1/responses",
                json={"model": "hermes-agent", "input": 42},
            )
            assert resp.status == 400
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth on endpoints
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEndpointAuth:
    """With an API key configured, check which endpoints reject requests
    that carry no Authorization header."""

    @pytest.mark.asyncio
    async def test_chat_completions_requires_auth(self, auth_adapter):
        """POST /v1/chat/completions without credentials gets 401."""
        payload = {"model": "test", "messages": [{"role": "user", "content": "hi"}]}
        async with TestClient(TestServer(_create_app(auth_adapter))) as client:
            response = await client.post("/v1/chat/completions", json=payload)
            assert response.status == 401

    @pytest.mark.asyncio
    async def test_responses_requires_auth(self, auth_adapter):
        """POST /v1/responses without credentials gets 401."""
        payload = {"model": "test", "input": "hi"}
        async with TestClient(TestServer(_create_app(auth_adapter))) as client:
            response = await client.post("/v1/responses", json=payload)
            assert response.status == 401

    @pytest.mark.asyncio
    async def test_models_requires_auth(self, auth_adapter):
        """GET /v1/models without credentials gets 401."""
        async with TestClient(TestServer(_create_app(auth_adapter))) as client:
            response = await client.get("/v1/models")
            assert response.status == 401

    @pytest.mark.asyncio
    async def test_health_does_not_require_auth(self, auth_adapter):
        """GET /health without credentials still gets 200."""
        async with TestClient(TestServer(_create_app(auth_adapter))) as client:
            response = await client.get("/health")
            assert response.status == 200
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Config integration
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConfigIntegration:
    """Check the API server's presence in gateway configuration: the Platform
    enum, API_SERVER_* environment overrides, and connected-platform listing."""

    def test_platform_enum_has_api_server(self):
        """Platform.API_SERVER carries the value "api_server"."""
        member_value = Platform.API_SERVER.value
        assert member_value == "api_server"

    def test_env_override_enables_api_server(self, monkeypatch):
        """API_SERVER_ENABLED=true yields an enabled API server platform entry."""
        monkeypatch.setenv("API_SERVER_ENABLED", "true")
        from gateway.config import load_gateway_config

        loaded = load_gateway_config()
        assert Platform.API_SERVER in loaded.platforms
        assert loaded.platforms[Platform.API_SERVER].enabled is True

    def test_env_override_with_key(self, monkeypatch):
        """API_SERVER_KEY alone creates the entry and stores the key in extra."""
        monkeypatch.setenv("API_SERVER_KEY", "sk-mykey")
        from gateway.config import load_gateway_config

        loaded = load_gateway_config()
        assert Platform.API_SERVER in loaded.platforms
        server_extra = loaded.platforms[Platform.API_SERVER].extra
        assert server_extra.get("key") == "sk-mykey"

    def test_env_override_port_and_host(self, monkeypatch):
        """Port (as an int) and host from the environment end up in extra."""
        for name, value in (
            ("API_SERVER_ENABLED", "true"),
            ("API_SERVER_PORT", "9999"),
            ("API_SERVER_HOST", "0.0.0.0"),
        ):
            monkeypatch.setenv(name, value)
        from gateway.config import load_gateway_config

        loaded = load_gateway_config()
        assert loaded.platforms[Platform.API_SERVER].extra.get("port") == 9999
        assert loaded.platforms[Platform.API_SERVER].extra.get("host") == "0.0.0.0"

    def test_env_override_cors_origins(self, monkeypatch):
        """A comma-separated origins variable is stored in extra as a list."""
        monkeypatch.setenv("API_SERVER_ENABLED", "true")
        monkeypatch.setenv(
            "API_SERVER_CORS_ORIGINS",
            "http://localhost:3000, http://127.0.0.1:3000",
        )
        from gateway.config import load_gateway_config

        loaded = load_gateway_config()
        origins = loaded.platforms[Platform.API_SERVER].extra.get("cors_origins")
        assert origins == [
            "http://localhost:3000",
            "http://127.0.0.1:3000",
        ]

    def test_api_server_in_connected_platforms(self):
        """An enabled API server entry is reported as connected."""
        gateway = GatewayConfig()
        gateway.platforms[Platform.API_SERVER] = PlatformConfig(enabled=True)
        assert Platform.API_SERVER in gateway.get_connected_platforms()

    def test_api_server_not_in_connected_when_disabled(self):
        """A disabled API server entry is not reported as connected."""
        gateway = GatewayConfig()
        gateway.platforms[Platform.API_SERVER] = PlatformConfig(enabled=False)
        assert Platform.API_SERVER not in gateway.get_connected_platforms()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Multiple system messages
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMultipleSystemMessages:
    """Chat completion requests that carry more than one system message."""

    @pytest.mark.asyncio
    async def test_multiple_system_messages_concatenated(self, adapter):
        """Every system message text appears in the ephemeral system prompt."""
        agent_result = {"final_response": "OK", "messages": [], "api_calls": 1}
        zero_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        payload = {
            "model": "hermes-agent",
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "system", "content": "Be concise."},
                {"role": "user", "content": "Hello"},
            ],
        }

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, zero_usage)
                response = await client.post("/v1/chat/completions", json=payload)

            assert response.status == 200
            # The mock keeps its recorded call after the patch is undone.
            system_prompt = fake_run.call_args.kwargs["ephemeral_system_prompt"]
            assert "You are helpful." in system_prompt
            assert "Be concise." in system_prompt
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# send() method (not used but required by base)
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSendMethod:
    """Behaviour of the adapter's send() method."""

    @pytest.mark.asyncio
    async def test_send_returns_not_supported(self):
        """send() reports failure with an error that mentions HTTP request/response."""
        api_adapter = APIServerAdapter(PlatformConfig(enabled=True))
        outcome = await api_adapter.send("chat1", "hello")
        assert outcome.success is False
        assert "HTTP request/response" in outcome.error
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# GET /v1/responses/{response_id}
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetResponse:
    """GET /v1/responses/{response_id}: retrieval, missing ids and auth."""

    @pytest.mark.asyncio
    async def test_get_stored_response(self, adapter):
        """A response created through POST can be fetched back by its id."""
        agent_result = {"final_response": "Hello!", "messages": [], "api_calls": 1}
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            # Step 1: create a response so there is something to retrieve.
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, usage)
                created = await client.post(
                    "/v1/responses",
                    json={"model": "hermes-agent", "input": "Hi"},
                )

            assert created.status == 200
            created_body = await created.json()
            response_id = created_body["id"]

            # Step 2: fetch it again by id.
            fetched = await client.get(f"/v1/responses/{response_id}")
            assert fetched.status == 200
            fetched_body = await fetched.json()
            assert fetched_body["id"] == response_id
            assert fetched_body["object"] == "response"
            assert fetched_body["status"] == "completed"

    @pytest.mark.asyncio
    async def test_get_not_found(self, adapter):
        """An id that was never stored is answered with 404."""
        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            response = await client.get("/v1/responses/resp_nonexistent")
            assert response.status == 404

    @pytest.mark.asyncio
    async def test_get_requires_auth(self, auth_adapter):
        """A GET sent without any auth header is answered with 401."""
        server = TestServer(_create_app(auth_adapter))
        async with TestClient(server) as client:
            response = await client.get("/v1/responses/resp_any")
            assert response.status == 401
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# DELETE /v1/responses/{response_id}
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDeleteResponse:
    """DELETE /v1/responses/{response_id}: removal, missing ids and auth."""

    @pytest.mark.asyncio
    async def test_delete_stored_response(self, adapter):
        """DELETE removes a stored response and returns confirmation."""
        mock_result = {"final_response": "Hello!", "messages": [], "api_calls": 1}

        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (mock_result, {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0})
                resp = await cli.post(
                    "/v1/responses",
                    json={"model": "hermes-agent", "input": "Hi"},
                )

            # Fail with a clear status mismatch (rather than a KeyError on
            # "id" below) if the setup request itself did not succeed.
            assert resp.status == 200
            data = await resp.json()
            response_id = data["id"]

            # Delete it
            resp2 = await cli.delete(f"/v1/responses/{response_id}")
            assert resp2.status == 200
            data2 = await resp2.json()
            assert data2["id"] == response_id
            assert data2["object"] == "response"
            assert data2["deleted"] is True

            # Verify it's gone
            resp3 = await cli.get(f"/v1/responses/{response_id}")
            assert resp3.status == 404

    @pytest.mark.asyncio
    async def test_delete_not_found(self, adapter):
        """Deleting an id that was never stored is answered with 404."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.delete("/v1/responses/resp_nonexistent")
            assert resp.status == 404

    @pytest.mark.asyncio
    async def test_delete_requires_auth(self, auth_adapter):
        """A DELETE sent without any auth header is answered with 401."""
        app = _create_app(auth_adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.delete("/v1/responses/resp_any")
            assert resp.status == 401
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Tool calls in output
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestToolCallsInOutput:
    """Shape of the Responses API ``output`` list with and without tool calls."""

    @pytest.mark.asyncio
    async def test_tool_calls_in_output(self, adapter):
        """Tool calls and their results are listed as items before the final message."""
        assistant_tool_turn = {
            "role": "assistant",
            "content": None,
            "tool_calls": [
                {
                    "id": "call_abc123",
                    "function": {
                        "name": "calculator",
                        "arguments": '{"expression": "6*7"}',
                    },
                }
            ],
        }
        tool_turn = {
            "role": "tool",
            "tool_call_id": "call_abc123",
            "content": "42",
        }
        final_turn = {
            "role": "assistant",
            "content": "The result is 42.",
        }
        agent_result = {
            "final_response": "The result is 42.",
            "messages": [assistant_tool_turn, tool_turn, final_turn],
            "api_calls": 2,
        }
        zero_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, zero_usage)
                response = await client.post(
                    "/v1/responses",
                    json={"model": "hermes-agent", "input": "What is 6*7?"},
                )

            assert response.status == 200
            body = await response.json()
            items = body["output"]

            # Expected order: function_call, function_call_output, message.
            assert len(items) == 3
            call_item, result_item, message_item = items
            assert call_item["type"] == "function_call"
            assert call_item["name"] == "calculator"
            assert call_item["arguments"] == '{"expression": "6*7"}'
            assert call_item["call_id"] == "call_abc123"
            assert result_item["type"] == "function_call_output"
            assert result_item["call_id"] == "call_abc123"
            assert result_item["output"] == "42"
            assert message_item["type"] == "message"
            assert message_item["content"][0]["text"] == "The result is 42."

    @pytest.mark.asyncio
    async def test_no_tool_calls_still_works(self, adapter):
        """With an empty message list the output holds exactly one message item."""
        agent_result = {"final_response": "Hello!", "messages": [], "api_calls": 1}
        zero_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, zero_usage)
                response = await client.post(
                    "/v1/responses",
                    json={"model": "hermes-agent", "input": "Hello"},
                )

            assert response.status == 200
            body = await response.json()
            assert len(body["output"]) == 1
            assert body["output"][0]["type"] == "message"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Usage / token counting
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestUsageCounting:
    """Token usage reported by both endpoints mirrors what the agent run returned."""

    @pytest.mark.asyncio
    async def test_responses_usage(self, adapter):
        """/v1/responses echoes input/output/total token counts unchanged."""
        agent_result = {"final_response": "Done", "messages": [], "api_calls": 1}
        reported_usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, reported_usage)
                response = await client.post(
                    "/v1/responses",
                    json={"model": "hermes-agent", "input": "Hi"},
                )

            assert response.status == 200
            body = await response.json()
            assert body["usage"]["input_tokens"] == 100
            assert body["usage"]["output_tokens"] == 50
            assert body["usage"]["total_tokens"] == 150

    @pytest.mark.asyncio
    async def test_chat_completions_usage(self, adapter):
        """/v1/chat/completions reports the counts as prompt/completion/total tokens."""
        agent_result = {"final_response": "Done", "messages": [], "api_calls": 1}
        reported_usage = {"input_tokens": 200, "output_tokens": 80, "total_tokens": 280}
        payload = {
            "model": "hermes-agent",
            "messages": [{"role": "user", "content": "Hi"}],
        }

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, reported_usage)
                response = await client.post("/v1/chat/completions", json=payload)

            assert response.status == 200
            body = await response.json()
            assert body["usage"]["prompt_tokens"] == 200
            assert body["usage"]["completion_tokens"] == 80
            assert body["usage"]["total_tokens"] == 280
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Truncation
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestTruncation:
    """Effect of the ``truncation`` request field on chained conversation history."""

    @staticmethod
    def _seed_long_history(adapter, response_id):
        """Store a previous response under *response_id* with 150 user messages."""
        history = [{"role": "user", "content": f"msg {i}"} for i in range(150)]
        adapter._response_store.put(response_id, {
            "response": {"id": response_id, "object": "response"},
            "conversation_history": history,
            "instructions": None,
        })

    @pytest.mark.asyncio
    async def test_truncation_auto_limits_history(self, adapter):
        """With truncation=auto the agent receives at most 100 history messages."""
        agent_result = {"final_response": "OK", "messages": [], "api_calls": 1}
        zero_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        self._seed_long_history(adapter, "resp_prev")
        payload = {
            "model": "hermes-agent",
            "input": "follow up",
            "previous_response_id": "resp_prev",
            "truncation": "auto",
        }

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, zero_usage)
                response = await client.post("/v1/responses", json=payload)

            assert response.status == 200
            passed_history = fake_run.call_args.kwargs["conversation_history"]
            # The 150 seeded messages must have been cut down to the cap.
            assert len(passed_history) <= 100

    @pytest.mark.asyncio
    async def test_no_truncation_keeps_full_history(self, adapter):
        """Without a truncation field all 150 stored messages reach the agent."""
        agent_result = {"final_response": "OK", "messages": [], "api_calls": 1}
        zero_usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        self._seed_long_history(adapter, "resp_prev2")
        payload = {
            "model": "hermes-agent",
            "input": "follow up",
            "previous_response_id": "resp_prev2",
        }

        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as fake_run:
                fake_run.return_value = (agent_result, zero_usage)
                response = await client.post("/v1/responses", json=payload)

            assert response.status == 200
            passed_history = fake_run.call_args.kwargs["conversation_history"]
            assert len(passed_history) == 150
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# CORS
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCORS:
    """Origin allow-listing and CORS headers, at helper level and over HTTP."""

    def test_origin_allowed_for_non_browser_client(self, adapter):
        """An empty Origin value is accepted."""
        verdict = adapter._origin_allowed("")
        assert verdict is True

    def test_origin_rejected_by_default(self, adapter):
        """The fixture adapter turns away an arbitrary origin."""
        verdict = adapter._origin_allowed("http://evil.example")
        assert verdict is False

    def test_origin_allowed_for_allowlist_match(self):
        """An origin listed in cors_origins is accepted."""
        allowlisted = _make_adapter(cors_origins=["http://localhost:3000"])
        assert allowlisted._origin_allowed("http://localhost:3000") is True

    def test_cors_headers_for_origin_disabled_by_default(self, adapter):
        """The fixture adapter produces no CORS headers for an origin."""
        cors_headers = adapter._cors_headers_for_origin("http://localhost:3000")
        assert cors_headers is None

    def test_cors_headers_for_origin_matches_allowlist(self):
        """An allow-listed origin gets headers echoing it and permitting POST."""
        allowlisted = _make_adapter(cors_origins=["http://localhost:3000"])
        cors_headers = allowlisted._cors_headers_for_origin("http://localhost:3000")
        assert cors_headers is not None
        assert cors_headers["Access-Control-Allow-Origin"] == "http://localhost:3000"
        assert "POST" in cors_headers["Access-Control-Allow-Methods"]

    def test_cors_headers_for_origin_rejects_unknown_origin(self):
        """An origin missing from the allowlist gets no CORS headers."""
        allowlisted = _make_adapter(cors_origins=["http://localhost:3000"])
        assert allowlisted._cors_headers_for_origin("http://evil.example") is None

    @pytest.mark.asyncio
    async def test_cors_headers_not_present_by_default(self, adapter):
        """CORS is disabled unless explicitly configured."""
        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            response = await client.get("/health")
            assert response.status == 200
            assert response.headers.get("Access-Control-Allow-Origin") is None

    @pytest.mark.asyncio
    async def test_browser_origin_rejected_by_default(self, adapter):
        """Browser-originated requests are rejected unless explicitly allowed."""
        browser_headers = {"Origin": "http://evil.example"}
        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            response = await client.get("/health", headers=browser_headers)
            assert response.status == 403
            assert response.headers.get("Access-Control-Allow-Origin") is None

    @pytest.mark.asyncio
    async def test_cors_options_preflight_rejected_by_default(self, adapter):
        """Browser preflight is rejected unless CORS is explicitly configured."""
        preflight_headers = {
            "Origin": "http://evil.example",
            "Access-Control-Request-Method": "POST",
        }
        server = TestServer(_create_app(adapter))
        async with TestClient(server) as client:
            response = await client.options(
                "/v1/chat/completions",
                headers=preflight_headers,
            )
            assert response.status == 403
            assert response.headers.get("Access-Control-Allow-Origin") is None

    @pytest.mark.asyncio
    async def test_cors_headers_present_for_allowed_origin(self):
        """Allowed origins receive explicit CORS headers."""
        allowlisted = _make_adapter(cors_origins=["http://localhost:3000"])
        server = TestServer(_create_app(allowlisted))
        async with TestClient(server) as client:
            response = await client.get("/health", headers={"Origin": "http://localhost:3000"})
            assert response.status == 200
            assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
            allowed_methods = response.headers.get("Access-Control-Allow-Methods", "")
            assert "POST" in allowed_methods
            assert "DELETE" in allowed_methods

    @pytest.mark.asyncio
    async def test_cors_options_preflight_allowed_for_configured_origin(self):
        """Configured origins can complete browser preflight."""
        allowlisted = _make_adapter(cors_origins=["http://localhost:3000"])
        preflight_headers = {
            "Origin": "http://localhost:3000",
            "Access-Control-Request-Method": "POST",
            "Access-Control-Request-Headers": "Authorization, Content-Type",
        }
        server = TestServer(_create_app(allowlisted))
        async with TestClient(server) as client:
            response = await client.options(
                "/v1/chat/completions",
                headers=preflight_headers,
            )
            assert response.status == 200
            assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
            assert "Authorization" in response.headers.get("Access-Control-Allow-Headers", "")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Conversation parameter
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestConversationParameter:
    """The ``conversation`` field of POST /v1/responses: naming, chaining, isolation."""

    @pytest.mark.asyncio
    async def test_conversation_creates_new(self, adapter):
        """First request with a conversation name works (new conversation)."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (
                    {"final_response": "Hello!", "messages": [], "api_calls": 1},
                    {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
                )
                resp = await cli.post("/v1/responses", json={
                    "input": "hi",
                    "conversation": "my-chat",
                })
                assert resp.status == 200
                data = await resp.json()
                assert data["status"] == "completed"
                # Conversation mapping should be set
                assert adapter._response_store.get_conversation("my-chat") is not None

    @pytest.mark.asyncio
    async def test_conversation_chains_automatically(self, adapter):
        """Second request with same conversation name chains to first."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (
                    {"final_response": "First response", "messages": [], "api_calls": 1},
                    {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
                )
                # First request
                resp1 = await cli.post("/v1/responses", json={
                    "input": "hello",
                    "conversation": "test-conv",
                })
                assert resp1.status == 200

                # Second request — should chain
                mock_run.return_value = (
                    {"final_response": "Second response", "messages": [], "api_calls": 1},
                    {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
                )
                resp2 = await cli.post("/v1/responses", json={
                    "input": "follow up",
                    "conversation": "test-conv",
                })
                assert resp2.status == 200

                # The second call should have received conversation history from the first
                assert mock_run.call_count == 2
                # ``call.kwargs`` already is the keyword dict of the recorded
                # call, so no positional-index fallback is needed.
                second_call = mock_run.call_args_list[1]
                history = second_call.kwargs.get("conversation_history", [])
                # History should be non-empty (contains messages from first response)
                assert len(history) > 0

    @pytest.mark.asyncio
    async def test_conversation_and_previous_response_id_conflict(self, adapter):
        """Cannot use both conversation and previous_response_id."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            resp = await cli.post("/v1/responses", json={
                "input": "hi",
                "conversation": "my-chat",
                "previous_response_id": "resp_abc123",
            })
            assert resp.status == 400
            data = await resp.json()
            assert "Cannot use both" in data["error"]["message"]

    @pytest.mark.asyncio
    async def test_separate_conversations_are_isolated(self, adapter):
        """Different conversation names have independent histories."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (
                    {"final_response": "Response A", "messages": [], "api_calls": 1},
                    {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
                )
                # Conversation A
                resp_a = await cli.post("/v1/responses", json={"input": "conv-a msg", "conversation": "conv-a"})
                assert resp_a.status == 200
                # Conversation B
                mock_run.return_value = (
                    {"final_response": "Response B", "messages": [], "api_calls": 1},
                    {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
                )
                resp_b = await cli.post("/v1/responses", json={"input": "conv-b msg", "conversation": "conv-b"})
                assert resp_b.status == 200

                mapped_a = adapter._response_store.get_conversation("conv-a")
                mapped_b = adapter._response_store.get_conversation("conv-b")
                # Both mappings must exist; otherwise the inequality below
                # would also hold when only one of them was recorded.
                assert mapped_a is not None
                assert mapped_b is not None
                # They should have different response IDs in the mapping
                assert mapped_a != mapped_b

    @pytest.mark.asyncio
    async def test_conversation_store_false_no_mapping(self, adapter):
        """If store=false, conversation mapping is not updated."""
        app = _create_app(adapter)
        async with TestClient(TestServer(app)) as cli:
            with patch.object(adapter, "_run_agent", new_callable=AsyncMock) as mock_run:
                mock_run.return_value = (
                    {"final_response": "Ephemeral", "messages": [], "api_calls": 1},
                    {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
                )
                resp = await cli.post("/v1/responses", json={
                    "input": "hi",
                    "conversation": "ephemeral-chat",
                    "store": False,
                })
                assert resp.status == 200
                # Conversation mapping should NOT be set since store=false
                assert adapter._response_store.get_conversation("ephemeral-chat") is None
|