Add tests for atropos tool calling integration
- test_tool_call_parsers.py: 16 tests for parser registry, hermes parser (single/multiple/truncated/malformed), and ParseResult contract validation - test_agent_loop.py: 21 tests for HermesAgentLoop with mock servers (text responses, tool calls, max turns, unknown tools, API errors, extra_body forwarding, managed state, blocked tools, reasoning extraction) - test_managed_server_tool_support.py: 9 tests validating API compatibility between hermes-agent and atroposlib's ManagedServer tool_call_parser support (gracefully skips on baseline atroposlib, passes on tool_call_support branch)
This commit is contained in:
483
tests/test_agent_loop.py
Normal file
483
tests/test_agent_loop.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
Tests for environments/agent_loop.py — HermesAgentLoop.
|
||||
|
||||
Tests the multi-turn agent engine using mocked servers, without needing
|
||||
real API keys or running servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure repo root is importable
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from environments.agent_loop import (
|
||||
AgentResult,
|
||||
HermesAgentLoop,
|
||||
ToolError,
|
||||
_extract_reasoning_from_message,
|
||||
resize_tool_pool,
|
||||
)
|
||||
|
||||
|
||||
# ─── Mock server infrastructure ─────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockFunction:
|
||||
name: str
|
||||
arguments: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockToolCall:
|
||||
id: str
|
||||
function: MockFunction
|
||||
type: str = "function"
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockMessage:
|
||||
content: Optional[str]
|
||||
role: str = "assistant"
|
||||
tool_calls: Optional[List[MockToolCall]] = None
|
||||
reasoning_content: Optional[str] = None
|
||||
reasoning: Optional[str] = None
|
||||
reasoning_details: Optional[list] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChoice:
|
||||
message: MockMessage
|
||||
finish_reason: str = "stop"
|
||||
index: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockChatCompletion:
|
||||
choices: List[MockChoice]
|
||||
id: str = "chatcmpl-mock"
|
||||
model: str = "mock-model"
|
||||
|
||||
|
||||
class MockServer:
|
||||
"""
|
||||
Mock server that returns pre-configured responses in sequence.
|
||||
Mimics the chat_completion() interface.
|
||||
"""
|
||||
|
||||
def __init__(self, responses: List[MockChatCompletion]):
|
||||
self.responses = responses
|
||||
self.call_count = 0
|
||||
self.call_history: List[Dict[str, Any]] = []
|
||||
|
||||
async def chat_completion(self, **kwargs) -> MockChatCompletion:
|
||||
self.call_history.append(kwargs)
|
||||
if self.call_count >= len(self.responses):
|
||||
# Return a simple text response if we run out
|
||||
return MockChatCompletion(
|
||||
choices=[MockChoice(message=MockMessage(content="Done."))]
|
||||
)
|
||||
resp = self.responses[self.call_count]
|
||||
self.call_count += 1
|
||||
return resp
|
||||
|
||||
|
||||
def make_text_response(content: str) -> MockChatCompletion:
|
||||
"""Create a simple text-only response (no tool calls)."""
|
||||
return MockChatCompletion(
|
||||
choices=[MockChoice(message=MockMessage(content=content))]
|
||||
)
|
||||
|
||||
|
||||
def make_tool_response(
|
||||
tool_name: str,
|
||||
arguments: dict,
|
||||
content: str = "",
|
||||
tool_call_id: str = "call_001",
|
||||
) -> MockChatCompletion:
|
||||
"""Create a response with a single tool call."""
|
||||
return MockChatCompletion(
|
||||
choices=[
|
||||
MockChoice(
|
||||
message=MockMessage(
|
||||
content=content,
|
||||
tool_calls=[
|
||||
MockToolCall(
|
||||
id=tool_call_id,
|
||||
function=MockFunction(
|
||||
name=tool_name,
|
||||
arguments=json.dumps(arguments),
|
||||
),
|
||||
)
|
||||
],
|
||||
),
|
||||
finish_reason="tool_calls",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# ─── Tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentResult:
|
||||
def test_defaults(self):
|
||||
result = AgentResult(messages=[])
|
||||
assert result.messages == []
|
||||
assert result.managed_state is None
|
||||
assert result.turns_used == 0
|
||||
assert result.finished_naturally is False
|
||||
assert result.reasoning_per_turn == []
|
||||
assert result.tool_errors == []
|
||||
|
||||
|
||||
class TestExtractReasoning:
|
||||
def test_reasoning_content_field(self):
|
||||
msg = MockMessage(content="hello", reasoning_content="I think...")
|
||||
assert _extract_reasoning_from_message(msg) == "I think..."
|
||||
|
||||
def test_reasoning_field(self):
|
||||
msg = MockMessage(content="hello", reasoning="Let me consider...")
|
||||
assert _extract_reasoning_from_message(msg) == "Let me consider..."
|
||||
|
||||
def test_reasoning_details(self):
|
||||
detail = MagicMock()
|
||||
detail.text = "Detail reasoning"
|
||||
msg = MockMessage(content="hello", reasoning_details=[detail])
|
||||
assert _extract_reasoning_from_message(msg) == "Detail reasoning"
|
||||
|
||||
def test_reasoning_details_dict_format(self):
|
||||
msg = MockMessage(
|
||||
content="hello",
|
||||
reasoning_details=[{"text": "Dict reasoning"}],
|
||||
)
|
||||
assert _extract_reasoning_from_message(msg) == "Dict reasoning"
|
||||
|
||||
def test_no_reasoning(self):
|
||||
msg = MockMessage(content="hello")
|
||||
assert _extract_reasoning_from_message(msg) is None
|
||||
|
||||
def test_reasoning_content_takes_priority(self):
|
||||
msg = MockMessage(
|
||||
content="hello",
|
||||
reasoning_content="First",
|
||||
reasoning="Second",
|
||||
)
|
||||
assert _extract_reasoning_from_message(msg) == "First"
|
||||
|
||||
|
||||
class TestHermesAgentLoop:
|
||||
"""Test the agent loop with mock servers."""
|
||||
|
||||
@pytest.fixture
|
||||
def basic_tools(self):
|
||||
"""Minimal tool schema for testing."""
|
||||
return [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "terminal",
|
||||
"description": "Run a command",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "Command to run",
|
||||
}
|
||||
},
|
||||
"required": ["command"],
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "read_file",
|
||||
"description": "Read a file",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string"},
|
||||
},
|
||||
"required": ["path"],
|
||||
},
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
@pytest.fixture
|
||||
def valid_names(self):
|
||||
return {"terminal", "read_file", "todo"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_text_response(self, basic_tools, valid_names):
|
||||
"""Model responds with text only, no tool calls."""
|
||||
server = MockServer([make_text_response("Hello! How can I help?")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is True
|
||||
assert result.turns_used == 1
|
||||
assert len(result.messages) >= 2 # user + assistant
|
||||
assert result.messages[-1]["role"] == "assistant"
|
||||
assert result.messages[-1]["content"] == "Hello! How can I help?"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_then_text(self, basic_tools, valid_names):
|
||||
"""Model calls a tool, then responds with text."""
|
||||
server = MockServer([
|
||||
make_tool_response("todo", {"todos": [{"id": "1", "content": "test", "status": "pending"}]}),
|
||||
make_text_response("I created a todo for you."),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Create a todo"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is True
|
||||
assert result.turns_used == 2
|
||||
# Should have: user, assistant (tool_call), tool (result), assistant (text)
|
||||
roles = [m["role"] for m in result.messages]
|
||||
assert roles == ["user", "assistant", "tool", "assistant"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_turns_reached(self, basic_tools, valid_names):
|
||||
"""Model keeps calling tools until max_turns is hit."""
|
||||
# Create responses that always call a tool
|
||||
responses = [
|
||||
make_tool_response("todo", {"todos": [{"id": str(i), "content": f"task {i}", "status": "pending"}]}, tool_call_id=f"call_{i}")
|
||||
for i in range(10)
|
||||
]
|
||||
server = MockServer(responses)
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=3,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Keep going"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is False
|
||||
assert result.turns_used == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_name(self, basic_tools, valid_names):
|
||||
"""Model calls a tool not in valid_tool_names."""
|
||||
server = MockServer([
|
||||
make_tool_response("nonexistent_tool", {"arg": "val"}),
|
||||
make_text_response("OK, that didn't work."),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Call something weird"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Should record a tool error
|
||||
assert len(result.tool_errors) >= 1
|
||||
assert result.tool_errors[0].tool_name == "nonexistent_tool"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_response(self, basic_tools, valid_names):
|
||||
"""Server returns empty response."""
|
||||
server = MockServer([MockChatCompletion(choices=[])])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is False
|
||||
assert result.turns_used == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_error_handling(self, basic_tools, valid_names):
|
||||
"""Server raises an exception."""
|
||||
|
||||
class FailingServer:
|
||||
async def chat_completion(self, **kwargs):
|
||||
raise ConnectionError("Server unreachable")
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=FailingServer(),
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.finished_naturally is False
|
||||
assert result.turns_used == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tools_passed_to_server(self, basic_tools, valid_names):
|
||||
"""Verify tools are passed in the chat_completion kwargs."""
|
||||
server = MockServer([make_text_response("OK")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
await agent.run(messages)
|
||||
|
||||
assert len(server.call_history) == 1
|
||||
assert "tools" in server.call_history[0]
|
||||
assert server.call_history[0]["tools"] == basic_tools
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extra_body_forwarded(self, basic_tools, valid_names):
|
||||
"""extra_body should be forwarded to server."""
|
||||
extra = {"provider": {"ignore": ["DeepInfra"]}}
|
||||
server = MockServer([make_text_response("OK")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
extra_body=extra,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
await agent.run(messages)
|
||||
|
||||
assert server.call_history[0].get("extra_body") == extra
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_managed_state_returned(self, basic_tools, valid_names):
|
||||
"""If server has get_state(), result should include managed_state."""
|
||||
server = MockServer([make_text_response("OK")])
|
||||
server.get_state = lambda: {"nodes": [{"test": True}]}
|
||||
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.managed_state is not None
|
||||
assert "nodes" in result.managed_state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_managed_state_without_get_state(self, basic_tools, valid_names):
|
||||
"""Regular server without get_state() should return None managed_state."""
|
||||
server = MockServer([make_text_response("OK")])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Hi"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert result.managed_state is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_tool_blocked(self, basic_tools):
|
||||
"""Memory tool should return error in RL environments."""
|
||||
valid = {"terminal", "read_file", "todo", "memory"}
|
||||
server = MockServer([
|
||||
make_tool_response("memory", {"action": "add", "target": "user", "content": "test"}),
|
||||
make_text_response("Done"),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Remember this"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
# Find the tool response
|
||||
tool_msgs = [m for m in result.messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) >= 1
|
||||
tool_result = json.loads(tool_msgs[0]["content"])
|
||||
assert "error" in tool_result
|
||||
assert "not available" in tool_result["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_search_blocked(self, basic_tools):
|
||||
"""session_search should return error in RL environments."""
|
||||
valid = {"terminal", "read_file", "todo", "session_search"}
|
||||
server = MockServer([
|
||||
make_tool_response("session_search", {"query": "test"}),
|
||||
make_text_response("Done"),
|
||||
])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "Search sessions"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
tool_msgs = [m for m in result.messages if m["role"] == "tool"]
|
||||
assert len(tool_msgs) >= 1
|
||||
tool_result = json.loads(tool_msgs[0]["content"])
|
||||
assert "error" in tool_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reasoning_content_preserved(self, basic_tools, valid_names):
|
||||
"""Reasoning content should be extracted and preserved."""
|
||||
resp = MockChatCompletion(
|
||||
choices=[
|
||||
MockChoice(
|
||||
message=MockMessage(
|
||||
content="The answer is 42.",
|
||||
reasoning_content="Let me think about this step by step...",
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
server = MockServer([resp])
|
||||
agent = HermesAgentLoop(
|
||||
server=server,
|
||||
tool_schemas=basic_tools,
|
||||
valid_tool_names=valid_names,
|
||||
max_turns=10,
|
||||
)
|
||||
messages = [{"role": "user", "content": "What is the meaning of life?"}]
|
||||
result = await agent.run(messages)
|
||||
|
||||
assert len(result.reasoning_per_turn) == 1
|
||||
assert result.reasoning_per_turn[0] == "Let me think about this step by step..."
|
||||
|
||||
|
||||
class TestResizeToolPool:
|
||||
def test_resize_works(self):
|
||||
"""resize_tool_pool should not raise."""
|
||||
resize_tool_pool(16) # Small pool for testing
|
||||
resize_tool_pool(128) # Restore default
|
||||
173
tests/test_managed_server_tool_support.py
Normal file
173
tests/test_managed_server_tool_support.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""
|
||||
Tests for ManagedServer tool_call_parser integration.
|
||||
|
||||
Validates that:
|
||||
1. ManagedServer accepts tool_call_parser parameter (tool_call_support branch)
|
||||
2. ServerManager.managed_server() passes tool_call_parser through
|
||||
3. The parser's parse() output is correctly attached to ChatCompletion responses
|
||||
4. hermes-agent's tool_call_parsers are compatible with ManagedServer's expectations
|
||||
|
||||
These tests verify the contract between hermes-agent's environments/ code
|
||||
and atroposlib's ManagedServer. They detect API incompatibilities early.
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
|
||||
class TestManagedServerAPI:
|
||||
"""Test that ManagedServer's API matches what hermes-agent expects."""
|
||||
|
||||
def test_managed_server_init_signature(self):
|
||||
"""ManagedServer should accept tool_call_parser parameter."""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
sig = inspect.signature(ManagedServer.__init__)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
# Core params that must exist
|
||||
assert "self" in params
|
||||
assert "server" in params
|
||||
assert "tokenizer" in params
|
||||
assert "track_tree" in params
|
||||
|
||||
# tool_call_parser — required for tool_call_support branch
|
||||
# If this fails, atroposlib hasn't been updated to tool_call_support
|
||||
has_tool_parser = "tool_call_parser" in params
|
||||
if not has_tool_parser:
|
||||
pytest.skip(
|
||||
"ManagedServer does not have tool_call_parser param — "
|
||||
"baseline atroposlib (pre tool_call_support branch)"
|
||||
)
|
||||
|
||||
def test_server_manager_managed_server_signature(self):
|
||||
"""ServerManager.managed_server() should accept tool_call_parser."""
|
||||
from atroposlib.envs.server_handling.server_manager import ServerManager
|
||||
|
||||
sig = inspect.signature(ServerManager.managed_server)
|
||||
params = list(sig.parameters.keys())
|
||||
|
||||
assert "self" in params
|
||||
assert "tokenizer" in params
|
||||
|
||||
has_tool_parser = "tool_call_parser" in params
|
||||
if not has_tool_parser:
|
||||
pytest.skip(
|
||||
"ServerManager.managed_server() does not have tool_call_parser param — "
|
||||
"baseline atroposlib (pre tool_call_support branch)"
|
||||
)
|
||||
|
||||
def test_managed_server_chat_template_kwargs(self):
|
||||
"""ManagedServer should have CHAT_TEMPLATE_KWARGS for forwarding tools/thinking."""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
if not hasattr(ManagedServer, "CHAT_TEMPLATE_KWARGS"):
|
||||
pytest.skip(
|
||||
"ManagedServer does not have CHAT_TEMPLATE_KWARGS — "
|
||||
"baseline atroposlib (pre tool_call_support branch)"
|
||||
)
|
||||
|
||||
kwargs = ManagedServer.CHAT_TEMPLATE_KWARGS
|
||||
assert "tools" in kwargs, "tools must be in CHAT_TEMPLATE_KWARGS"
|
||||
|
||||
def test_no_get_logprobs_method(self):
|
||||
"""get_logprobs should be removed in tool_call_support branch."""
|
||||
from atroposlib.envs.server_handling.managed_server import ManagedServer
|
||||
|
||||
# In baseline, get_logprobs exists. In tool_call_support, it's removed.
|
||||
# We just note the state — not a hard fail either way.
|
||||
has_get_logprobs = hasattr(ManagedServer, "get_logprobs")
|
||||
if has_get_logprobs:
|
||||
pytest.skip(
|
||||
"ManagedServer still has get_logprobs — baseline atroposlib"
|
||||
)
|
||||
|
||||
|
||||
class TestParserCompatibility:
|
||||
"""Test that hermes-agent's parsers match ManagedServer's expectations."""
|
||||
|
||||
def test_parser_parse_returns_correct_format(self):
|
||||
"""
|
||||
ManagedServer expects parser.parse(text) -> (content, tool_calls)
|
||||
where tool_calls is a list of objects with .id, .function.name, .function.arguments
|
||||
"""
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
|
||||
tc = tool_calls[0]
|
||||
# ManagedServer accesses these attrs directly
|
||||
assert hasattr(tc, "id")
|
||||
assert hasattr(tc, "function")
|
||||
assert hasattr(tc.function, "name")
|
||||
assert hasattr(tc.function, "arguments")
|
||||
|
||||
def test_parser_no_tools_returns_none(self):
|
||||
"""ManagedServer checks `if parsed_tool_calls:` — None should be falsy."""
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
content, tool_calls = parser.parse("Just text, no tools")
|
||||
assert tool_calls is None
|
||||
|
||||
def test_parser_content_is_string_or_none(self):
|
||||
"""ManagedServer uses `parsed_content or ""` — must be str or None."""
|
||||
from environments.tool_call_parsers import get_parser
|
||||
|
||||
parser = get_parser("hermes")
|
||||
|
||||
# With tool calls
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>'
|
||||
content, _ = parser.parse(text)
|
||||
assert content is None or isinstance(content, str)
|
||||
|
||||
# Without tool calls
|
||||
content2, _ = parser.parse("Just text")
|
||||
assert isinstance(content2, str)
|
||||
|
||||
|
||||
class TestBaseEnvCompatibility:
|
||||
"""Test that hermes_base_env.py's managed_server() call matches the API."""
|
||||
|
||||
def test_hermes_base_env_managed_server_call_pattern(self):
|
||||
"""
|
||||
Verify that hermes_base_env.py passes tool_call_parser to managed_server().
|
||||
This is a source-level check — the actual managed_server() call must match.
|
||||
"""
|
||||
import ast
|
||||
|
||||
base_env_path = Path(__file__).parent.parent / "environments" / "hermes_base_env.py"
|
||||
source = base_env_path.read_text()
|
||||
tree = ast.parse(source)
|
||||
|
||||
# Find the managed_server() call
|
||||
found_tool_call_parser_kwarg = False
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.Call):
|
||||
# Look for self.server.managed_server(...)
|
||||
if isinstance(node.func, ast.Attribute) and node.func.attr == "managed_server":
|
||||
for kw in node.keywords:
|
||||
if kw.arg == "tool_call_parser":
|
||||
found_tool_call_parser_kwarg = True
|
||||
|
||||
assert found_tool_call_parser_kwarg, (
|
||||
"hermes_base_env.py should pass tool_call_parser= to managed_server()"
|
||||
)
|
||||
|
||||
def test_hermes_base_env_uses_get_parser(self):
|
||||
"""Verify hermes_base_env imports and uses get_parser from tool_call_parsers."""
|
||||
base_env_path = Path(__file__).parent.parent / "environments" / "hermes_base_env.py"
|
||||
source = base_env_path.read_text()
|
||||
|
||||
assert "from environments.tool_call_parsers import get_parser" in source
|
||||
assert "get_parser(" in source
|
||||
156
tests/test_tool_call_parsers.py
Normal file
156
tests/test_tool_call_parsers.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Tests for environments/tool_call_parsers/ — client-side tool call parsers.
|
||||
|
||||
These parsers extract structured tool_calls from raw model output text.
|
||||
Used in Phase 2 (VLLM/generate) where the server returns raw tokens.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Ensure repo root is importable
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
|
||||
from environments.tool_call_parsers import (
|
||||
ParseResult,
|
||||
ToolCallParser,
|
||||
get_parser,
|
||||
list_parsers,
|
||||
)
|
||||
|
||||
|
||||
# ─── Registry tests ─────────────────────────────────────────────────────
|
||||
|
||||
class TestParserRegistry:
|
||||
def test_list_parsers_returns_nonempty(self):
|
||||
parsers = list_parsers()
|
||||
assert len(parsers) > 0
|
||||
|
||||
def test_hermes_parser_registered(self):
|
||||
parsers = list_parsers()
|
||||
assert "hermes" in parsers
|
||||
|
||||
def test_get_parser_returns_instance(self):
|
||||
parser = get_parser("hermes")
|
||||
assert isinstance(parser, ToolCallParser)
|
||||
|
||||
def test_get_parser_unknown_raises(self):
|
||||
with pytest.raises(KeyError):
|
||||
get_parser("nonexistent_parser_xyz")
|
||||
|
||||
def test_all_registered_parsers_instantiate(self):
|
||||
"""Every registered parser should be importable and instantiable."""
|
||||
for name in list_parsers():
|
||||
parser = get_parser(name)
|
||||
assert isinstance(parser, ToolCallParser)
|
||||
assert hasattr(parser, "parse")
|
||||
|
||||
|
||||
# ─── Hermes parser tests ────────────────────────────────────────────────
|
||||
|
||||
class TestHermesParser:
|
||||
@pytest.fixture
|
||||
def parser(self):
|
||||
return get_parser("hermes")
|
||||
|
||||
def test_no_tool_call(self, parser):
|
||||
text = "Hello, I can help you with that."
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert content == text
|
||||
assert tool_calls is None
|
||||
|
||||
def test_single_tool_call(self, parser):
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls -la"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "terminal"
|
||||
args = json.loads(tool_calls[0].function.arguments)
|
||||
assert args["command"] == "ls -la"
|
||||
|
||||
def test_tool_call_with_surrounding_text(self, parser):
|
||||
text = 'Let me check that for you.\n<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 1
|
||||
assert tool_calls[0].function.name == "terminal"
|
||||
# Content should have the surrounding text
|
||||
if content is not None:
|
||||
assert "check that" in content or content.strip() != ""
|
||||
|
||||
def test_multiple_tool_calls(self, parser):
|
||||
text = (
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>\n'
|
||||
'<tool_call>{"name": "read_file", "arguments": {"path": "test.py"}}</tool_call>'
|
||||
)
|
||||
content, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
assert len(tool_calls) == 2
|
||||
names = {tc.function.name for tc in tool_calls}
|
||||
assert "terminal" in names
|
||||
assert "read_file" in names
|
||||
|
||||
def test_tool_call_ids_are_unique(self, parser):
|
||||
text = (
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": "ls"}}</tool_call>\n'
|
||||
'<tool_call>{"name": "terminal", "arguments": {"command": "pwd"}}</tool_call>'
|
||||
)
|
||||
_, tool_calls = parser.parse(text)
|
||||
assert tool_calls is not None
|
||||
ids = [tc.id for tc in tool_calls]
|
||||
assert len(ids) == len(set(ids)), "Tool call IDs must be unique"
|
||||
|
||||
def test_empty_string(self, parser):
|
||||
content, tool_calls = parser.parse("")
|
||||
assert tool_calls is None
|
||||
|
||||
def test_malformed_json_in_tool_call(self, parser):
|
||||
text = '<tool_call>not valid json</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
# Should either return None tool_calls or handle gracefully
|
||||
# (implementation may vary — some parsers return error tool calls)
|
||||
|
||||
def test_truncated_tool_call(self, parser):
|
||||
"""Test handling of unclosed tool_call tag (model truncated mid-generation)."""
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "ls -la"}'
|
||||
content, tool_calls = parser.parse(text)
|
||||
# Parser should handle truncated output gracefully
|
||||
# Either parse it successfully or return None
|
||||
|
||||
|
||||
# ─── Parse result contract tests (applies to ALL parsers) ───────────────
|
||||
|
||||
class TestParseResultContract:
|
||||
"""Ensure all parsers conform to the ParseResult contract."""
|
||||
|
||||
@pytest.fixture(params=["hermes"]) # Add more as needed
|
||||
def parser(self, request):
|
||||
return get_parser(request.param)
|
||||
|
||||
def test_returns_tuple_of_two(self, parser):
|
||||
result = parser.parse("hello world")
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_no_tools_returns_none_tool_calls(self, parser):
|
||||
content, tool_calls = parser.parse("Just plain text, no tools.")
|
||||
assert tool_calls is None
|
||||
assert content is not None
|
||||
|
||||
def test_tool_calls_are_proper_objects(self, parser):
|
||||
"""When tool calls are found, they should be ChatCompletionMessageToolCall objects."""
|
||||
# Use hermes format since that's universal
|
||||
text = '<tool_call>{"name": "terminal", "arguments": {"command": "echo hi"}}</tool_call>'
|
||||
content, tool_calls = parser.parse(text)
|
||||
if tool_calls is not None:
|
||||
for tc in tool_calls:
|
||||
assert hasattr(tc, "id")
|
||||
assert hasattr(tc, "function")
|
||||
assert hasattr(tc.function, "name")
|
||||
assert hasattr(tc.function, "arguments")
|
||||
assert tc.id is not None
|
||||
assert isinstance(tc.function.name, str)
|
||||
assert isinstance(tc.function.arguments, str)
|
||||
Reference in New Issue
Block a user