"""
Tests for environments/agent_loop.py — HermesAgentLoop.

Exercises the multi-turn agent engine against scripted mock servers, so no
real API keys or running backends are needed.
"""

import asyncio
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional
from unittest.mock import MagicMock

import pytest

# Make the repository root importable regardless of where pytest is invoked.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from environments.agent_loop import (
    AgentResult,
    HermesAgentLoop,
    ToolError,
    _extract_reasoning_from_message,
    resize_tool_pool,
)


# ─── Mock server infrastructure ───────────────────────────────────────────


@dataclass
class MockFunction:
    # Mirrors the ``.function`` payload of an OpenAI-style tool call.
    name: str
    arguments: str


@dataclass
class MockToolCall:
    id: str
    function: MockFunction
    type: str = "function"


@dataclass
class MockMessage:
    # Assistant message double; the three reasoning fields cover the
    # provider variants that _extract_reasoning_from_message understands.
    content: Optional[str]
    role: str = "assistant"
    tool_calls: Optional[List[MockToolCall]] = None
    reasoning_content: Optional[str] = None
    reasoning: Optional[str] = None
    reasoning_details: Optional[list] = None


@dataclass
class MockChoice:
    message: MockMessage
    finish_reason: str = "stop"
    index: int = 0


@dataclass
class MockChatCompletion:
    choices: List[MockChoice]
    id: str = "chatcmpl-mock"
    model: str = "mock-model"


class MockServer:
    """Scripted stand-in for a chat server.

    Replays pre-configured completions in order through the same
    ``chat_completion()`` coroutine interface the real server exposes.
    """

    def __init__(self, responses: List[MockChatCompletion]):
        self.responses = responses
        self.call_count = 0  # number of scripted responses served so far
        self.call_history: List[Dict[str, Any]] = []  # kwargs of every call

    async def chat_completion(self, **kwargs) -> MockChatCompletion:
        self.call_history.append(kwargs)
        try:
            scripted = self.responses[self.call_count]
        except IndexError:
            # Script exhausted — fall back to a plain text completion.
            return MockChatCompletion(
                choices=[MockChoice(message=MockMessage(content="Done."))]
            )
        self.call_count += 1
        return scripted


def make_text_response(content: str) -> MockChatCompletion:
    """Build a plain-text completion carrying no tool calls."""
    message = MockMessage(content=content)
    return MockChatCompletion(choices=[MockChoice(message=message)])


def make_tool_response(
    tool_name: str,
    arguments: dict,
    content: str = "",
    tool_call_id: str = "call_001",
) -> MockChatCompletion:
    """Build a completion carrying exactly one tool call."""
    call = MockToolCall(
        id=tool_call_id,
        function=MockFunction(name=tool_name, arguments=json.dumps(arguments)),
    )
    choice = MockChoice(
        message=MockMessage(content=content, tool_calls=[call]),
        finish_reason="tool_calls",
    )
    return MockChatCompletion(choices=[choice])


# ─── Tests ─────────────────────────────────────────────────────────────────


class TestAgentResult:
    def test_defaults(self):
        # A bare AgentResult should start with empty / falsy bookkeeping.
        result = AgentResult(messages=[])
        assert result.messages == []
        assert result.managed_state is None
        assert result.turns_used == 0
        assert result.finished_naturally is False
        assert result.reasoning_per_turn == []
        assert result.tool_errors == []


class TestExtractReasoning:
    def test_reasoning_content_field(self):
        msg = MockMessage(content="hello", reasoning_content="I think...")
        assert _extract_reasoning_from_message(msg) == "I think..."
+ + def test_reasoning_field(self): + msg = MockMessage(content="hello", reasoning="Let me consider...") + assert _extract_reasoning_from_message(msg) == "Let me consider..." + + def test_reasoning_details(self): + detail = MagicMock() + detail.text = "Detail reasoning" + msg = MockMessage(content="hello", reasoning_details=[detail]) + assert _extract_reasoning_from_message(msg) == "Detail reasoning" + + def test_reasoning_details_dict_format(self): + msg = MockMessage( + content="hello", + reasoning_details=[{"text": "Dict reasoning"}], + ) + assert _extract_reasoning_from_message(msg) == "Dict reasoning" + + def test_no_reasoning(self): + msg = MockMessage(content="hello") + assert _extract_reasoning_from_message(msg) is None + + def test_reasoning_content_takes_priority(self): + msg = MockMessage( + content="hello", + reasoning_content="First", + reasoning="Second", + ) + assert _extract_reasoning_from_message(msg) == "First" + + +class TestHermesAgentLoop: + """Test the agent loop with mock servers.""" + + @pytest.fixture + def basic_tools(self): + """Minimal tool schema for testing.""" + return [ + { + "type": "function", + "function": { + "name": "terminal", + "description": "Run a command", + "parameters": { + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Command to run", + } + }, + "required": ["command"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read a file", + "parameters": { + "type": "object", + "properties": { + "path": {"type": "string"}, + }, + "required": ["path"], + }, + }, + }, + ] + + @pytest.fixture + def valid_names(self): + return {"terminal", "read_file", "todo"} + + @pytest.mark.asyncio + async def test_simple_text_response(self, basic_tools, valid_names): + """Model responds with text only, no tool calls.""" + server = MockServer([make_text_response("Hello! 
How can I help?")]) + agent = HermesAgentLoop( + server=server, + tool_schemas=basic_tools, + valid_tool_names=valid_names, + max_turns=10, + ) + messages = [{"role": "user", "content": "Hi"}] + result = await agent.run(messages) + + assert result.finished_naturally is True + assert result.turns_used == 1 + assert len(result.messages) >= 2 # user + assistant + assert result.messages[-1]["role"] == "assistant" + assert result.messages[-1]["content"] == "Hello! How can I help?" + + @pytest.mark.asyncio + async def test_tool_call_then_text(self, basic_tools, valid_names): + """Model calls a tool, then responds with text.""" + server = MockServer([ + make_tool_response("todo", {"todos": [{"id": "1", "content": "test", "status": "pending"}]}), + make_text_response("I created a todo for you."), + ]) + agent = HermesAgentLoop( + server=server, + tool_schemas=basic_tools, + valid_tool_names=valid_names, + max_turns=10, + ) + messages = [{"role": "user", "content": "Create a todo"}] + result = await agent.run(messages) + + assert result.finished_naturally is True + assert result.turns_used == 2 + # Should have: user, assistant (tool_call), tool (result), assistant (text) + roles = [m["role"] for m in result.messages] + assert roles == ["user", "assistant", "tool", "assistant"] + + @pytest.mark.asyncio + async def test_max_turns_reached(self, basic_tools, valid_names): + """Model keeps calling tools until max_turns is hit.""" + # Create responses that always call a tool + responses = [ + make_tool_response("todo", {"todos": [{"id": str(i), "content": f"task {i}", "status": "pending"}]}, tool_call_id=f"call_{i}") + for i in range(10) + ] + server = MockServer(responses) + agent = HermesAgentLoop( + server=server, + tool_schemas=basic_tools, + valid_tool_names=valid_names, + max_turns=3, + ) + messages = [{"role": "user", "content": "Keep going"}] + result = await agent.run(messages) + + assert result.finished_naturally is False + assert result.turns_used == 3 + + 
    @pytest.mark.asyncio
    async def test_unknown_tool_name(self, basic_tools, valid_names):
        """Model calls a tool not in valid_tool_names."""
        server = MockServer([
            make_tool_response("nonexistent_tool", {"arg": "val"}),
            make_text_response("OK, that didn't work."),
        ])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Call something weird"}]
        result = await agent.run(messages)

        # Should record a tool error (ToolError carries the offending name).
        assert len(result.tool_errors) >= 1
        assert result.tool_errors[0].tool_name == "nonexistent_tool"

    @pytest.mark.asyncio
    async def test_empty_response(self, basic_tools, valid_names):
        """Server returns empty response."""
        # A completion with no choices at all — the loop should stop, not crash.
        server = MockServer([MockChatCompletion(choices=[])])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Hi"}]
        result = await agent.run(messages)

        assert result.finished_naturally is False
        assert result.turns_used == 1

    @pytest.mark.asyncio
    async def test_api_error_handling(self, basic_tools, valid_names):
        """Server raises an exception."""

        # Server double whose chat_completion always raises.
        class FailingServer:
            async def chat_completion(self, **kwargs):
                raise ConnectionError("Server unreachable")

        agent = HermesAgentLoop(
            server=FailingServer(),
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Hi"}]
        result = await agent.run(messages)

        # The loop absorbs the API error rather than propagating it.
        assert result.finished_naturally is False
        assert result.turns_used == 1

    @pytest.mark.asyncio
    async def test_tools_passed_to_server(self, basic_tools, valid_names):
        """Verify tools are passed in the chat_completion kwargs."""
        server = MockServer([make_text_response("OK")])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Hi"}]
        await agent.run(messages)

        # Exactly one call, and the schemas are forwarded unmodified.
        assert len(server.call_history) == 1
        assert "tools" in server.call_history[0]
        assert server.call_history[0]["tools"] == basic_tools

    @pytest.mark.asyncio
    async def test_extra_body_forwarded(self, basic_tools, valid_names):
        """extra_body should be forwarded to server."""
        extra = {"provider": {"ignore": ["DeepInfra"]}}
        server = MockServer([make_text_response("OK")])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
            extra_body=extra,
        )
        messages = [{"role": "user", "content": "Hi"}]
        await agent.run(messages)

        assert server.call_history[0].get("extra_body") == extra

    @pytest.mark.asyncio
    async def test_managed_state_returned(self, basic_tools, valid_names):
        """If server has get_state(), result should include managed_state."""
        server = MockServer([make_text_response("OK")])
        # Duck-type the mock into a managed server by attaching get_state.
        server.get_state = lambda: {"nodes": [{"test": True}]}

        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Hi"}]
        result = await agent.run(messages)

        assert result.managed_state is not None
        assert "nodes" in result.managed_state

    @pytest.mark.asyncio
    async def test_no_managed_state_without_get_state(self, basic_tools, valid_names):
        """Regular server without get_state() should return None managed_state."""
        server = MockServer([make_text_response("OK")])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Hi"}]
        result = await agent.run(messages)

        assert result.managed_state is None

    @pytest.mark.asyncio
    async def test_memory_tool_blocked(self, basic_tools):
        """Memory tool should return error in RL environments."""
        # "memory" is a valid name here, but the loop is expected to refuse it.
        valid = {"terminal", "read_file", "todo", "memory"}
        server = MockServer([
            make_tool_response("memory", {"action": "add", "target": "user", "content": "test"}),
            make_text_response("Done"),
        ])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Remember this"}]
        result = await agent.run(messages)

        # Find the tool response
        tool_msgs = [m for m in result.messages if m["role"] == "tool"]
        assert len(tool_msgs) >= 1
        # Tool results are JSON-encoded strings carrying an "error" key here.
        tool_result = json.loads(tool_msgs[0]["content"])
        assert "error" in tool_result
        assert "not available" in tool_result["error"].lower()

    @pytest.mark.asyncio
    async def test_session_search_blocked(self, basic_tools):
        """session_search should return error in RL environments."""
        valid = {"terminal", "read_file", "todo", "session_search"}
        server = MockServer([
            make_tool_response("session_search", {"query": "test"}),
            make_text_response("Done"),
        ])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "Search sessions"}]
        result = await agent.run(messages)

        tool_msgs = [m for m in result.messages if m["role"] == "tool"]
        assert len(tool_msgs) >= 1
        tool_result = json.loads(tool_msgs[0]["content"])
        assert "error" in tool_result

    @pytest.mark.asyncio
    async def test_reasoning_content_preserved(self, basic_tools, valid_names):
        """Reasoning content should be extracted and preserved."""
        resp = MockChatCompletion(
            choices=[
                MockChoice(
                    message=MockMessage(
                        content="The answer is 42.",
                        reasoning_content="Let me think about this step by step...",
                    )
                )
            ]
        )
        server = MockServer([resp])
        agent = HermesAgentLoop(
            server=server,
            tool_schemas=basic_tools,
            valid_tool_names=valid_names,
            max_turns=10,
        )
        messages = [{"role": "user", "content": "What is the meaning of life?"}]
        result = await agent.run(messages)

        # One entry per turn, carrying the reasoning verbatim.
        assert len(result.reasoning_per_turn) == 1
        assert result.reasoning_per_turn[0] == "Let me think about this step by step..."


class TestResizeToolPool:
    def test_resize_works(self):
        """resize_tool_pool should not raise."""
        resize_tool_pool(16)  # Small pool for testing
        resize_tool_pool(128)  # Restore default
+""" + +import inspect +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) + + +class TestManagedServerAPI: + """Test that ManagedServer's API matches what hermes-agent expects.""" + + def test_managed_server_init_signature(self): + """ManagedServer should accept tool_call_parser parameter.""" + from atroposlib.envs.server_handling.managed_server import ManagedServer + + sig = inspect.signature(ManagedServer.__init__) + params = list(sig.parameters.keys()) + + # Core params that must exist + assert "self" in params + assert "server" in params + assert "tokenizer" in params + assert "track_tree" in params + + # tool_call_parser — required for tool_call_support branch + # If this fails, atroposlib hasn't been updated to tool_call_support + has_tool_parser = "tool_call_parser" in params + if not has_tool_parser: + pytest.skip( + "ManagedServer does not have tool_call_parser param — " + "baseline atroposlib (pre tool_call_support branch)" + ) + + def test_server_manager_managed_server_signature(self): + """ServerManager.managed_server() should accept tool_call_parser.""" + from atroposlib.envs.server_handling.server_manager import ServerManager + + sig = inspect.signature(ServerManager.managed_server) + params = list(sig.parameters.keys()) + + assert "self" in params + assert "tokenizer" in params + + has_tool_parser = "tool_call_parser" in params + if not has_tool_parser: + pytest.skip( + "ServerManager.managed_server() does not have tool_call_parser param — " + "baseline atroposlib (pre tool_call_support branch)" + ) + + def test_managed_server_chat_template_kwargs(self): + """ManagedServer should have CHAT_TEMPLATE_KWARGS for forwarding tools/thinking.""" + from atroposlib.envs.server_handling.managed_server import ManagedServer + + if not hasattr(ManagedServer, "CHAT_TEMPLATE_KWARGS"): + pytest.skip( + "ManagedServer does not have CHAT_TEMPLATE_KWARGS — " + "baseline atroposlib (pre 
tool_call_support branch)" + ) + + kwargs = ManagedServer.CHAT_TEMPLATE_KWARGS + assert "tools" in kwargs, "tools must be in CHAT_TEMPLATE_KWARGS" + + def test_no_get_logprobs_method(self): + """get_logprobs should be removed in tool_call_support branch.""" + from atroposlib.envs.server_handling.managed_server import ManagedServer + + # In baseline, get_logprobs exists. In tool_call_support, it's removed. + # We just note the state — not a hard fail either way. + has_get_logprobs = hasattr(ManagedServer, "get_logprobs") + if has_get_logprobs: + pytest.skip( + "ManagedServer still has get_logprobs — baseline atroposlib" + ) + + +class TestParserCompatibility: + """Test that hermes-agent's parsers match ManagedServer's expectations.""" + + def test_parser_parse_returns_correct_format(self): + """ + ManagedServer expects parser.parse(text) -> (content, tool_calls) + where tool_calls is a list of objects with .id, .function.name, .function.arguments + """ + from environments.tool_call_parsers import get_parser + + parser = get_parser("hermes") + text = '{"name": "terminal", "arguments": {"command": "ls"}}' + content, tool_calls = parser.parse(text) + + assert tool_calls is not None + assert len(tool_calls) == 1 + + tc = tool_calls[0] + # ManagedServer accesses these attrs directly + assert hasattr(tc, "id") + assert hasattr(tc, "function") + assert hasattr(tc.function, "name") + assert hasattr(tc.function, "arguments") + + def test_parser_no_tools_returns_none(self): + """ManagedServer checks `if parsed_tool_calls:` — None should be falsy.""" + from environments.tool_call_parsers import get_parser + + parser = get_parser("hermes") + content, tool_calls = parser.parse("Just text, no tools") + assert tool_calls is None + + def test_parser_content_is_string_or_none(self): + """ManagedServer uses `parsed_content or ""` — must be str or None.""" + from environments.tool_call_parsers import get_parser + + parser = get_parser("hermes") + + # With tool calls + text = '{"name": 
"terminal", "arguments": {"command": "ls"}}' + content, _ = parser.parse(text) + assert content is None or isinstance(content, str) + + # Without tool calls + content2, _ = parser.parse("Just text") + assert isinstance(content2, str) + + +class TestBaseEnvCompatibility: + """Test that hermes_base_env.py's managed_server() call matches the API.""" + + def test_hermes_base_env_managed_server_call_pattern(self): + """ + Verify that hermes_base_env.py passes tool_call_parser to managed_server(). + This is a source-level check — the actual managed_server() call must match. + """ + import ast + + base_env_path = Path(__file__).parent.parent / "environments" / "hermes_base_env.py" + source = base_env_path.read_text() + tree = ast.parse(source) + + # Find the managed_server() call + found_tool_call_parser_kwarg = False + for node in ast.walk(tree): + if isinstance(node, ast.Call): + # Look for self.server.managed_server(...) + if isinstance(node.func, ast.Attribute) and node.func.attr == "managed_server": + for kw in node.keywords: + if kw.arg == "tool_call_parser": + found_tool_call_parser_kwarg = True + + assert found_tool_call_parser_kwarg, ( + "hermes_base_env.py should pass tool_call_parser= to managed_server()" + ) + + def test_hermes_base_env_uses_get_parser(self): + """Verify hermes_base_env imports and uses get_parser from tool_call_parsers.""" + base_env_path = Path(__file__).parent.parent / "environments" / "hermes_base_env.py" + source = base_env_path.read_text() + + assert "from environments.tool_call_parsers import get_parser" in source + assert "get_parser(" in source diff --git a/tests/test_tool_call_parsers.py b/tests/test_tool_call_parsers.py new file mode 100644 index 000000000..6a07a2267 --- /dev/null +++ b/tests/test_tool_call_parsers.py @@ -0,0 +1,156 @@ +""" +Tests for environments/tool_call_parsers/ — client-side tool call parsers. + +These parsers extract structured tool_calls from raw model output text. 
"""
Tests for environments/tool_call_parsers/ — client-side tool call parsers.

These parsers extract structured tool_calls from raw model output text.
Used in Phase 2 (VLLM/generate) where the server returns raw tokens.
"""

import json
import sys
from pathlib import Path

import pytest

# Ensure repo root is importable
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from environments.tool_call_parsers import (
    ParseResult,
    ToolCallParser,
    get_parser,
    list_parsers,
)


# ─── Registry tests ─────────────────────────────────────────────────────

class TestParserRegistry:
    def test_list_parsers_returns_nonempty(self):
        parsers = list_parsers()
        assert len(parsers) > 0

    def test_hermes_parser_registered(self):
        parsers = list_parsers()
        assert "hermes" in parsers

    def test_get_parser_returns_instance(self):
        parser = get_parser("hermes")
        assert isinstance(parser, ToolCallParser)

    def test_get_parser_unknown_raises(self):
        with pytest.raises(KeyError):
            get_parser("nonexistent_parser_xyz")

    def test_all_registered_parsers_instantiate(self):
        """Every registered parser should be importable and instantiable."""
        for name in list_parsers():
            parser = get_parser(name)
            assert isinstance(parser, ToolCallParser)
            assert hasattr(parser, "parse")


# ─── Hermes parser tests ────────────────────────────────────────────────
# NOTE(review): the hermes format normally wraps calls in <tool_call> tags;
# these fixtures carry bare JSON — confirm the parser accepts both forms.

class TestHermesParser:
    @pytest.fixture
    def parser(self):
        return get_parser("hermes")

    def test_no_tool_call(self, parser):
        text = "Hello, I can help you with that."
        content, tool_calls = parser.parse(text)
        assert content == text
        assert tool_calls is None

    def test_single_tool_call(self, parser):
        text = '{"name": "terminal", "arguments": {"command": "ls -la"}}'
        content, tool_calls = parser.parse(text)
        assert tool_calls is not None
        assert len(tool_calls) == 1
        assert tool_calls[0].function.name == "terminal"
        # Arguments are a JSON string, per the OpenAI tool-call shape.
        args = json.loads(tool_calls[0].function.arguments)
        assert args["command"] == "ls -la"

    def test_tool_call_with_surrounding_text(self, parser):
        text = 'Let me check that for you.\n{"name": "terminal", "arguments": {"command": "pwd"}}'
        content, tool_calls = parser.parse(text)
        assert tool_calls is not None
        assert len(tool_calls) == 1
        assert tool_calls[0].function.name == "terminal"
        # Content should have the surrounding text
        if content is not None:
            assert "check that" in content or content.strip() != ""

    def test_multiple_tool_calls(self, parser):
        text = (
            '{"name": "terminal", "arguments": {"command": "ls"}}\n'
            '{"name": "read_file", "arguments": {"path": "test.py"}}'
        )
        content, tool_calls = parser.parse(text)
        assert tool_calls is not None
        assert len(tool_calls) == 2
        names = {tc.function.name for tc in tool_calls}
        assert "terminal" in names
        assert "read_file" in names

    def test_tool_call_ids_are_unique(self, parser):
        text = (
            '{"name": "terminal", "arguments": {"command": "ls"}}\n'
            '{"name": "terminal", "arguments": {"command": "pwd"}}'
        )
        _, tool_calls = parser.parse(text)
        assert tool_calls is not None
        ids = [tc.id for tc in tool_calls]
        assert len(ids) == len(set(ids)), "Tool call IDs must be unique"

    def test_empty_string(self, parser):
        content, tool_calls = parser.parse("")
        assert tool_calls is None

    def test_malformed_json_in_tool_call(self, parser):
        """Malformed input must be handled gracefully, never raised."""
        text = 'not valid json'
        content, tool_calls = parser.parse(text)
        # Implementations may vary: either no tool calls are detected, or the
        # parser returns an explicit (possibly error-flagged) list. Both are
        # acceptable — raising, or returning some other type, is not.
        assert tool_calls is None or isinstance(tool_calls, list)
        assert content is None or isinstance(content, str)

    def test_truncated_tool_call(self, parser):
        """Test handling of unclosed tool_call tag (model truncated mid-generation)."""
        text = '{"name": "terminal", "arguments": {"command": "ls -la"}'
        content, tool_calls = parser.parse(text)
        # Parser should handle truncated output gracefully: either parse it
        # successfully or report no tool calls — but it must not raise.
        assert tool_calls is None or isinstance(tool_calls, list)
        assert content is None or isinstance(content, str)


# ─── Parse result contract tests (applies to ALL parsers) ───────────────

class TestParseResultContract:
    """Ensure all parsers conform to the ParseResult contract."""

    @pytest.fixture(params=["hermes"])  # Add more as needed
    def parser(self, request):
        return get_parser(request.param)

    def test_returns_tuple_of_two(self, parser):
        result = parser.parse("hello world")
        assert isinstance(result, tuple)
        assert len(result) == 2

    def test_no_tools_returns_none_tool_calls(self, parser):
        content, tool_calls = parser.parse("Just plain text, no tools.")
        assert tool_calls is None
        assert content is not None

    def test_tool_calls_are_proper_objects(self, parser):
        """When tool calls are found, they should be ChatCompletionMessageToolCall objects."""
        # Use hermes format since that's universal
        text = '{"name": "terminal", "arguments": {"command": "echo hi"}}'
        content, tool_calls = parser.parse(text)
        if tool_calls is not None:
            for tc in tool_calls:
                assert hasattr(tc, "id")
                assert hasattr(tc, "function")
                assert hasattr(tc.function, "name")
                assert hasattr(tc.function, "arguments")
                assert tc.id is not None
                assert isinstance(tc.function.name, str)
                assert isinstance(tc.function.arguments, str)