feat(mcp): add sampling support — server-initiated LLM requests (#753)
Add MCP sampling/createMessage capability via SamplingHandler class. Text-only sampling + tool use in sampling with governance (rate limits, model whitelist, token caps, tool loop limits). Per-server audit metrics. Based on concept from PR #366 by eren-karakus0. Restructured as class-based design with bug fixes and tests using real MCP SDK types. 50 new tests, 2600 total passing.
This commit is contained in:
@@ -1489,3 +1489,781 @@ class TestUtilityToolRegistration:
|
||||
assert entry.check_fn() is False
|
||||
|
||||
_servers.pop("chk", None)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# SamplingHandler tests
|
||||
# ===========================================================================
|
||||
|
||||
import math
|
||||
import time
|
||||
|
||||
from mcp.types import (
|
||||
CreateMessageResult,
|
||||
CreateMessageResultWithTools,
|
||||
ErrorData,
|
||||
SamplingCapability,
|
||||
SamplingToolsCapability,
|
||||
TextContent,
|
||||
ToolUseContent,
|
||||
)
|
||||
|
||||
from tools.mcp_tool import SamplingHandler, _safe_numeric
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for sampling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_sampling_params(
|
||||
messages=None,
|
||||
max_tokens=100,
|
||||
system_prompt=None,
|
||||
model_preferences=None,
|
||||
temperature=None,
|
||||
stop_sequences=None,
|
||||
tools=None,
|
||||
tool_choice=None,
|
||||
):
|
||||
"""Create a fake CreateMessageRequestParams using SimpleNamespace.
|
||||
|
||||
Each message must have a ``content_as_list`` attribute that mirrors
|
||||
the SDK helper so that ``_convert_messages`` works correctly.
|
||||
"""
|
||||
if messages is None:
|
||||
content = SimpleNamespace(text="Hello")
|
||||
msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
|
||||
messages = [msg]
|
||||
|
||||
params = SimpleNamespace(
|
||||
messages=messages,
|
||||
maxTokens=max_tokens,
|
||||
modelPreferences=model_preferences,
|
||||
temperature=temperature,
|
||||
stopSequences=stop_sequences,
|
||||
tools=tools,
|
||||
toolChoice=tool_choice,
|
||||
)
|
||||
if system_prompt is not None:
|
||||
params.systemPrompt = system_prompt
|
||||
return params
|
||||
|
||||
|
||||
def _make_llm_response(
|
||||
content="LLM response",
|
||||
model="test-model",
|
||||
finish_reason="stop",
|
||||
tool_calls=None,
|
||||
):
|
||||
"""Create a fake OpenAI chat completion response (text)."""
|
||||
message = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||
choice = SimpleNamespace(
|
||||
finish_reason=finish_reason,
|
||||
message=message,
|
||||
)
|
||||
usage = SimpleNamespace(total_tokens=42)
|
||||
return SimpleNamespace(choices=[choice], model=model, usage=usage)
|
||||
|
||||
|
||||
def _make_llm_tool_response(tool_calls_data=None, model="test-model"):
    """Build a fake completion whose choice finished with ``tool_calls``.

    ``tool_calls_data`` is a list of ``(id, name, arguments_json)`` tuples;
    the default is a single ``get_weather`` call for London.
    """
    if tool_calls_data is None:
        tool_calls_data = [("call_1", "get_weather", '{"city": "London"}')]

    calls = []
    for call_id, fn_name, fn_args in tool_calls_data:
        calls.append(
            SimpleNamespace(
                id=call_id,
                function=SimpleNamespace(name=fn_name, arguments=fn_args),
            )
        )
    return _make_llm_response(
        content=None,
        model=model,
        finish_reason="tool_calls",
        tool_calls=calls,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. _safe_numeric helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSafeNumeric:
    """Coercion, defaulting, and clamping behaviour of ``_safe_numeric``."""

    def test_int_passthrough(self):
        result = _safe_numeric(10, 5, int)
        assert result == 10

    def test_string_coercion(self):
        result = _safe_numeric("20", 5, int)
        assert result == 20

    def test_none_returns_default(self):
        result = _safe_numeric(None, 7, int)
        assert result == 7

    def test_inf_returns_default(self):
        result = _safe_numeric(float("inf"), 3.0, float)
        assert result == 3.0

    def test_nan_returns_default(self):
        result = _safe_numeric(float("nan"), 4.0, float)
        assert result == 4.0

    def test_below_minimum_clamps(self):
        result = _safe_numeric(-5, 10, int, minimum=1)
        assert result == 1

    def test_minimum_zero_allowed(self):
        result = _safe_numeric(0, 10, int, minimum=0)
        assert result == 0

    def test_non_numeric_string_returns_default(self):
        result = _safe_numeric("abc", 42, int)
        assert result == 42

    def test_float_coercion(self):
        result = _safe_numeric("3.5", 1.0, float)
        assert result == 3.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. SamplingHandler initialization and config parsing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingHandlerInit:
    """Constructor defaults and config parsing for SamplingHandler."""

    def test_defaults(self):
        handler = SamplingHandler("srv", {})
        assert handler.server_name == "srv"
        assert handler.max_rpm == 10
        assert handler.timeout == 30
        assert handler.max_tokens_cap == 4096
        assert handler.max_tool_rounds == 5
        assert handler.model_override is None
        assert handler.allowed_models == []
        expected_metrics = {
            "requests": 0,
            "errors": 0,
            "tokens_used": 0,
            "tool_use_count": 0,
        }
        assert handler.metrics == expected_metrics

    def test_custom_config(self):
        handler = SamplingHandler(
            "custom",
            {
                "max_rpm": 20,
                "timeout": 60,
                "max_tokens_cap": 2048,
                "max_tool_rounds": 3,
                "model": "gpt-4o",
                "allowed_models": ["gpt-4o", "gpt-3.5-turbo"],
                "log_level": "debug",
            },
        )
        assert handler.max_rpm == 20
        assert handler.timeout == 60.0
        assert handler.max_tokens_cap == 2048
        assert handler.max_tool_rounds == 3
        assert handler.model_override == "gpt-4o"
        assert handler.allowed_models == ["gpt-4o", "gpt-3.5-turbo"]

    def test_string_numeric_config_values(self):
        """YAML sometimes delivers numeric values as strings."""
        handler = SamplingHandler(
            "s", {"max_rpm": "15", "timeout": "45.5", "max_tokens_cap": "1024"}
        )
        assert handler.max_rpm == 15
        assert handler.timeout == 45.5
        assert handler.max_tokens_cap == 1024
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Rate limiting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestRateLimit:
    """Sliding-window rate limiting via _check_rate_limit (max_rpm=3)."""

    def setup_method(self):
        self.handler = SamplingHandler("rl", {"max_rpm": 3})

    def test_allows_under_limit(self):
        for _ in range(3):
            assert self.handler._check_rate_limit() is True

    def test_rejects_over_limit(self):
        # Exhaust the window; the next call must be refused.
        for _ in range(3):
            self.handler._check_rate_limit()
        assert self.handler._check_rate_limit() is False

    def test_window_expiry(self):
        """Old timestamps should be purged from the sliding window."""
        for _ in range(3):
            self.handler._check_rate_limit()
        # Age every recorded timestamp past the 60-second window.
        stale = time.time() - 61
        self.handler._rate_timestamps[:] = [stale] * 3
        assert self.handler._check_rate_limit() is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Model resolution
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestResolveModel:
    """Model selection precedence: config override > client hint > None."""

    def setup_method(self):
        self.handler = SamplingHandler("mr", {})

    @staticmethod
    def _prefs(*hint_names):
        # Fake ModelPreferences with one hint object per given name.
        return SimpleNamespace(
            hints=[SimpleNamespace(name=name) for name in hint_names]
        )

    def test_no_preference_no_override(self):
        assert self.handler._resolve_model(None) is None

    def test_config_override_wins(self):
        self.handler.model_override = "override-model"
        resolved = self.handler._resolve_model(self._prefs("hint-model"))
        assert resolved == "override-model"

    def test_hint_used_when_no_override(self):
        resolved = self.handler._resolve_model(self._prefs("hint-model"))
        assert resolved == "hint-model"

    def test_empty_hints(self):
        assert self.handler._resolve_model(self._prefs()) is None

    def test_hint_without_name(self):
        assert self.handler._resolve_model(self._prefs(None)) is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Message conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestConvertMessages:
    """SamplingHandler._convert_messages: MCP content blocks -> OpenAI dicts."""

    def setup_method(self):
        self.handler = SamplingHandler("mc", {})

    def test_single_text_message(self):
        # A lone text block collapses to a plain string "content".
        content = SimpleNamespace(text="Hello world")
        msg = SimpleNamespace(role="user", content=content, content_as_list=[content])
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0] == {"role": "user", "content": "Hello world"}

    def test_image_message(self):
        # Image blocks become OpenAI "image_url" parts with a base64 data URI.
        text_block = SimpleNamespace(text="Look at this")
        img_block = SimpleNamespace(data="abc123", mimeType="image/png")
        msg = SimpleNamespace(
            role="user",
            content=[text_block, img_block],
            content_as_list=[text_block, img_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        parts = result[0]["content"]
        assert len(parts) == 2
        assert parts[0] == {"type": "text", "text": "Look at this"}
        assert parts[1]["type"] == "image_url"
        assert "data:image/png;base64,abc123" in parts[1]["image_url"]["url"]

    def test_tool_result_message(self):
        # A tool-result block maps to a role="tool" message carrying its id.
        inner = SimpleNamespace(text="42 degrees")
        tr_block = SimpleNamespace(toolUseId="call_1", content=[inner])
        msg = SimpleNamespace(
            role="user",
            content=[tr_block],
            content_as_list=[tr_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["role"] == "tool"
        assert result[0]["tool_call_id"] == "call_1"
        assert result[0]["content"] == "42 degrees"

    def test_tool_use_message(self):
        # An assistant tool-use block becomes a tool_calls entry with
        # JSON-encoded arguments.
        tu_block = SimpleNamespace(
            id="call_2", name="get_weather", input={"city": "London"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[tu_block],
            content_as_list=[tu_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["role"] == "assistant"
        assert len(result[0]["tool_calls"]) == 1
        assert result[0]["tool_calls"][0]["function"]["name"] == "get_weather"
        assert json.loads(result[0]["tool_calls"][0]["function"]["arguments"]) == {"city": "London"}

    def test_mixed_text_and_tool_use(self):
        """Assistant message with both text and tool_calls."""
        text_block = SimpleNamespace(text="Let me check the weather")
        tu_block = SimpleNamespace(
            id="call_3", name="get_weather", input={"city": "Paris"}
        )
        msg = SimpleNamespace(
            role="assistant",
            content=[text_block, tu_block],
            content_as_list=[text_block, tu_block],
        )
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        # Text lands in "content" alongside the tool_calls list.
        assert result[0]["content"] == "Let me check the weather"
        assert len(result[0]["tool_calls"]) == 1

    def test_fallback_without_content_as_list(self):
        """When content_as_list is absent, falls back to content."""
        content = SimpleNamespace(text="Fallback text")
        msg = SimpleNamespace(role="user", content=content)
        params = _make_sampling_params(messages=[msg])
        result = self.handler._convert_messages(params)
        assert len(result) == 1
        assert result[0]["content"] == "Fallback text"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Text-only sampling callback (full flow)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingCallbackText:
    """End-to-end text sampling: handler(ctx, params) -> CreateMessageResult."""

    def setup_method(self):
        self.handler = SamplingHandler("txt", {})

    def test_text_response(self):
        """Full flow: text response returns CreateMessageResult."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response(
            content="Hello from LLM"
        )

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))

        assert isinstance(result, CreateMessageResult)
        assert isinstance(result.content, TextContent)
        assert result.content.text == "Hello from LLM"
        # Model name is taken from the LLM response ("test-model"), not the
        # patched provider default.
        assert result.model == "test-model"
        assert result.role == "assistant"
        assert result.stopReason == "endTurn"

    def test_system_prompt_prepended(self):
        """System prompt is inserted as the first message."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params(system_prompt="Be helpful")
            asyncio.run(self.handler(None, params))

        # Inspect the message list actually sent to the LLM client.
        call_args = fake_client.chat.completions.create.call_args
        messages = call_args.kwargs["messages"]
        assert messages[0] == {"role": "system", "content": "Be helpful"}

    def test_length_stop_reason(self):
        """finish_reason='length' maps to stopReason='maxTokens'."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response(
            finish_reason="length"
        )

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))

        assert isinstance(result, CreateMessageResult)
        assert result.stopReason == "maxTokens"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Tool use sampling callback
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingCallbackToolUse:
    """LLM responses that finish with tool_calls become tool-use results."""

    def setup_method(self):
        self.handler = SamplingHandler("tu", {})

    def test_tool_use_response(self):
        """LLM tool_calls response returns CreateMessageResultWithTools."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            result = asyncio.run(self.handler(None, params))

        assert isinstance(result, CreateMessageResultWithTools)
        assert result.stopReason == "toolUse"
        assert result.model == "test-model"
        assert len(result.content) == 1
        tc = result.content[0]
        assert isinstance(tc, ToolUseContent)
        assert tc.name == "get_weather"
        assert tc.id == "call_1"
        # The arguments JSON string from the fake response is decoded to a dict.
        assert tc.input == {"city": "London"}

    def test_multiple_tool_calls(self):
        """Multiple tool_calls in a single response."""
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response(
            tool_calls_data=[
                ("call_a", "func_a", '{"x": 1}'),
                ("call_b", "func_b", '{"y": 2}'),
            ]
        )

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(self.handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResultWithTools)
        # Tool-use blocks preserve the order of the response's tool_calls.
        assert len(result.content) == 2
        assert result.content[0].name == "func_a"
        assert result.content[1].name == "func_b"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Tool loop governance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestToolLoopGovernance:
    """max_tool_rounds caps consecutive tool-use responses per handler."""

    def test_max_tool_rounds_enforcement(self):
        """After max_tool_rounds consecutive tool responses, an error is returned."""
        handler = SamplingHandler("tl", {"max_tool_rounds": 2})
        fake_client = MagicMock()
        # Every LLM call answers with a tool_calls response, so the loop
        # counter increments on each round.
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            params = _make_sampling_params()
            # Round 1, 2: allowed
            r1 = asyncio.run(handler(None, params))
            assert isinstance(r1, CreateMessageResultWithTools)
            r2 = asyncio.run(handler(None, params))
            assert isinstance(r2, CreateMessageResultWithTools)
            # Round 3: exceeds limit
            r3 = asyncio.run(handler(None, params))
            assert isinstance(r3, ErrorData)
            assert "Tool loop limit exceeded" in r3.message

    def test_text_response_resets_counter(self):
        """A text response resets the tool loop counter."""
        handler = SamplingHandler("tl2", {"max_tool_rounds": 1})
        fake_client = MagicMock()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            # Tool response (round 1 of 1 allowed)
            fake_client.chat.completions.create.return_value = _make_llm_tool_response()
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResultWithTools)

            # Text response resets counter
            fake_client.chat.completions.create.return_value = _make_llm_response()
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, CreateMessageResult)

            # Tool response again (should succeed since counter was reset)
            fake_client.chat.completions.create.return_value = _make_llm_tool_response()
            r3 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r3, CreateMessageResultWithTools)

    def test_max_tool_rounds_zero_disables(self):
        """max_tool_rounds=0 means tool loops are disabled entirely."""
        handler = SamplingHandler("tl3", {"max_tool_rounds": 0})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "Tool loops disabled" in result.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. Error paths: rate limit, timeout, no provider
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSamplingErrors:
    """Error paths: rate limiting, LLM call timeout, and missing provider."""

    def test_rate_limit_error(self):
        handler = SamplingHandler("rle", {"max_rpm": 1})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            # First call succeeds and consumes the single slot in the window.
            r1 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r1, CreateMessageResult)
            # Second call is rate limited and counted as an error.
            r2 = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(r2, ErrorData)
            assert "rate limit" in r2.message.lower()
            assert handler.metrics["errors"] == 1

    def test_timeout_error(self):
        handler = SamplingHandler("to", {"timeout": 0.05})
        fake_client = MagicMock()

        def slow_call(**kwargs):
            import threading

            # Block noticeably longer than the 0.05s handler timeout, but keep
            # the wait short: asyncio.run() joins the default executor on
            # shutdown (Python 3.9+), so this thread's full wait is added to
            # the test's wall-clock time.  The previous 5-second wait made
            # this single test take ~5s.
            threading.Event().wait(0.25)
            return _make_llm_response()

        fake_client.chat.completions.create.side_effect = slow_call

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "timed out" in result.message.lower()
            assert handler.metrics["errors"] == 1

    def test_no_provider_error(self):
        handler = SamplingHandler("np", {})

        # (None, None) simulates "no auxiliary LLM client configured".
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(None, None),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "No LLM provider" in result.message
            assert handler.metrics["errors"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. Model whitelist
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestModelWhitelist:
    """allowed_models restricts which resolved model names may be used."""

    def test_allowed_model_passes(self):
        handler = SamplingHandler("wl", {"allowed_models": ["gpt-4o", "test-model"]})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()

        # The provider default "test-model" is on the whitelist.
        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "test-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, CreateMessageResult)

    def test_disallowed_model_rejected(self):
        # The provider default "gpt-3.5-turbo" is NOT on the whitelist.
        handler = SamplingHandler("wl2", {"allowed_models": ["gpt-4o"]})
        fake_client = MagicMock()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "gpt-3.5-turbo"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, ErrorData)
            assert "not allowed" in result.message
            assert handler.metrics["errors"] == 1

    def test_empty_whitelist_allows_all(self):
        # An empty allowed_models list means no restriction.
        handler = SamplingHandler("wl3", {"allowed_models": []})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "any-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))
            assert isinstance(result, CreateMessageResult)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. Malformed tool_call arguments
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMalformedToolCallArgs:
    """Handling of tool_call arguments that are not valid JSON strings."""

    def test_invalid_json_wrapped_as_raw(self):
        """Malformed JSON arguments get wrapped in {"_raw": ...}."""
        handler = SamplingHandler("mf", {})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response(
            tool_calls_data=[("call_x", "some_tool", "not valid json {{{")]
        )

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResultWithTools)
        tc = result.content[0]
        assert isinstance(tc, ToolUseContent)
        # The unparseable string is preserved verbatim under "_raw".
        assert tc.input == {"_raw": "not valid json {{{"}

    def test_dict_args_pass_through(self):
        """When arguments are already a dict, they pass through directly."""
        handler = SamplingHandler("mf2", {})

        # Build a tool call where arguments is already a dict
        tc_obj = SimpleNamespace(
            id="call_d",
            function=SimpleNamespace(name="do_stuff", arguments={"key": "val"}),
        )
        message = SimpleNamespace(content=None, tool_calls=[tc_obj])
        choice = SimpleNamespace(finish_reason="tool_calls", message=message)
        usage = SimpleNamespace(total_tokens=10)
        response = SimpleNamespace(choices=[choice], model="m", usage=usage)

        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = response

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            result = asyncio.run(handler(None, _make_sampling_params()))

        assert isinstance(result, CreateMessageResultWithTools)
        assert result.content[0].input == {"key": "val"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 12. Metrics tracking
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMetricsTracking:
    """Per-handler audit metrics: requests, tokens_used, errors, tool_use_count."""

    def test_request_and_token_metrics(self):
        handler = SamplingHandler("met", {})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            asyncio.run(handler(None, _make_sampling_params()))

        # _make_llm_response reports usage.total_tokens == 42.
        assert handler.metrics["requests"] == 1
        assert handler.metrics["tokens_used"] == 42
        assert handler.metrics["errors"] == 0

    def test_tool_use_count_metric(self):
        handler = SamplingHandler("met2", {})
        fake_client = MagicMock()
        fake_client.chat.completions.create.return_value = _make_llm_tool_response()

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(fake_client, "default-model"),
        ):
            asyncio.run(handler(None, _make_sampling_params()))

        assert handler.metrics["tool_use_count"] == 1
        assert handler.metrics["requests"] == 1

    def test_error_metric_incremented(self):
        # A missing provider counts as an error, not as a completed request.
        handler = SamplingHandler("met3", {})

        with patch(
            "agent.auxiliary_client.get_text_auxiliary_client",
            return_value=(None, None),
        ):
            asyncio.run(handler(None, _make_sampling_params()))

        assert handler.metrics["errors"] == 1
        assert handler.metrics["requests"] == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. session_kwargs()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSessionKwargs:
    """session_kwargs() wires the handler itself in as the sampling callback."""

    def test_returns_correct_keys(self):
        handler = SamplingHandler("sk", {})
        session_args = handler.session_kwargs()
        assert "sampling_callback" in session_args
        assert "sampling_capabilities" in session_args
        # The handler instance itself is the callback (it is callable).
        assert session_args["sampling_callback"] is handler

    def test_sampling_capabilities_type(self):
        handler = SamplingHandler("sk2", {})
        capability = handler.session_kwargs()["sampling_capabilities"]
        assert isinstance(capability, SamplingCapability)
        assert isinstance(capability.tools, SamplingToolsCapability)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. MCPServerTask integration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestMCPServerTaskSamplingIntegration:
    """Wiring of SamplingHandler into MCPServerTask.

    NOTE(review): these tests replicate the sampling-setup snippet inline
    rather than invoking run() (which would open a real connection); if the
    setup logic in run() changes, these inline copies must be updated too.
    """

    def test_sampling_handler_created_when_enabled(self):
        """MCPServerTask.run() creates a SamplingHandler when sampling is enabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test")
        config = {
            "command": "fake",
            "sampling": {"enabled": True, "max_rpm": 5},
        }
        # We only need to test the setup logic, not the actual connection.
        # Calling run() would attempt a real connection, so we test the
        # sampling setup portion directly.
        server._config = config
        sampling_config = config.get("sampling", {})
        # Sampling defaults to enabled but also requires the MCP SDK sampling
        # types to be available (_MCP_SAMPLING_TYPES truthy).
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None

        assert server._sampling is not None
        assert isinstance(server._sampling, SamplingHandler)
        assert server._sampling.server_name == "int_test"
        assert server._sampling.max_rpm == 5

    def test_sampling_handler_none_when_disabled(self):
        """MCPServerTask._sampling is None when sampling is disabled."""
        from tools.mcp_tool import MCPServerTask, _MCP_SAMPLING_TYPES

        server = MCPServerTask("int_test2")
        config = {
            "command": "fake",
            "sampling": {"enabled": False},
        }
        server._config = config
        sampling_config = config.get("sampling", {})
        if sampling_config.get("enabled", True) and _MCP_SAMPLING_TYPES:
            server._sampling = SamplingHandler(server.name, sampling_config)
        else:
            server._sampling = None

        assert server._sampling is None

    def test_session_kwargs_used_in_stdio(self):
        """When sampling is set, session_kwargs() are passed to ClientSession."""
        from tools.mcp_tool import MCPServerTask

        server = MCPServerTask("sk_test")
        server._sampling = SamplingHandler("sk_test", {"max_rpm": 7})
        kwargs = server._sampling.session_kwargs()
        assert "sampling_callback" in kwargs
        assert "sampling_capabilities" in kwargs
|
||||
|
||||
Reference in New Issue
Block a user