From 1f9e7cd65989e4c26092d55747fe751c5e6f94bb Mon Sep 17 00:00:00 2001
From: dmahan93
Date: Mon, 9 Mar 2026 21:18:42 -0500
Subject: [PATCH] test: 5 vLLM integration tests + fallback tool call parser

Tests hit a real vLLM server (Qwen/Qwen3-4B-Thinking-2507) via ManagedServer
Phase 2. Auto-skip if server isn't running. Tests verify:

- Single tool call through full agent loop
- Multi-tool calls across turns
- ManagedServer produces SequenceNodes with tokens/logprobs
- Direct response without tools
- Thinking model produces <think> blocks

Also adds fallback parser in agent_loop.py: when ManagedServer's
ToolCallTranslator can't parse (vLLM not installed), hermes-agent's
standalone parsers extract <tool_call> tags from raw content.
---
 environments/agent_loop.py    |  30 ++-
 tests/test_agent_loop_vllm.py | 356 ++++++++++++++++++++++++++++++++++
 2 files changed, 385 insertions(+), 1 deletion(-)
 create mode 100644 tests/test_agent_loop_vllm.py

diff --git a/environments/agent_loop.py b/environments/agent_loop.py
index ce2b1f9b3..b2d29c6f9 100644
--- a/environments/agent_loop.py
+++ b/environments/agent_loop.py
@@ -249,7 +249,35 @@ class HermesAgentLoop:
         reasoning = _extract_reasoning_from_message(assistant_msg)
         reasoning_per_turn.append(reasoning)
 
-        # Check for tool calls -- standard OpenAI spec
+        # Check for tool calls -- standard OpenAI spec.
+        # Fallback: if response has no structured tool_calls but content
+        # contains raw tool call tags (e.g. <tool_call>), parse them using
+        # hermes-agent's standalone parsers. This handles the case where
+        # ManagedServer's ToolCallTranslator couldn't parse because vLLM
+        # isn't installed.
+        if (
+            not assistant_msg.tool_calls
+            and assistant_msg.content
+            and self.tool_schemas
+            and "<tool_call>" in (assistant_msg.content or "")
+        ):
+            try:
+                from environments.tool_call_parsers import get_parser
+                fallback_parser = get_parser("hermes")
+                parsed_content, parsed_calls = fallback_parser.parse(
+                    assistant_msg.content
+                )
+                if parsed_calls:
+                    assistant_msg.tool_calls = parsed_calls
+                    if parsed_content is not None:
+                        assistant_msg.content = parsed_content
+                    logger.debug(
+                        "Fallback parser extracted %d tool calls from raw content",
+                        len(parsed_calls),
+                    )
+            except Exception:
+                pass  # Fall through to no tool calls
+
         if assistant_msg.tool_calls:
             # Build the assistant message dict for conversation history
             msg_dict: Dict[str, Any] = {
diff --git a/tests/test_agent_loop_vllm.py b/tests/test_agent_loop_vllm.py
new file mode 100644
index 000000000..1a21d440c
--- /dev/null
+++ b/tests/test_agent_loop_vllm.py
@@ -0,0 +1,356 @@
+"""Integration tests for HermesAgentLoop with a local vLLM server.
+
+Tests the full Phase 2 flow: ManagedServer + tool calling with a real
+vLLM backend, producing actual token IDs and logprobs for RL training.
+
+Requires a running vLLM server. Start one from the atropos directory:
+
+    python -m example_trainer.vllm_api_server \
+        --model Qwen/Qwen3-4B-Thinking-2507 \
+        --port 9001 \
+        --gpu-memory-utilization 0.8 \
+        --max-model-len=32000
+
+Tests are automatically skipped if the server is not reachable.
+
+Run:
+    pytest tests/test_agent_loop_vllm.py -v
+    pytest tests/test_agent_loop_vllm.py -v -k "single"
+"""
+
+import asyncio
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Any, Dict
+from unittest.mock import patch
+
+import pytest
+import requests
+
+# Ensure repo root is importable
+_repo_root = Path(__file__).resolve().parent.parent
+if str(_repo_root) not in sys.path:
+    sys.path.insert(0, str(_repo_root))
+
+from environments.agent_loop import AgentResult, HermesAgentLoop
+
+
+# =========================================================================
+# Configuration
+# =========================================================================
+
+VLLM_HOST = "localhost"
+VLLM_PORT = 9001
+VLLM_BASE_URL = f"http://{VLLM_HOST}:{VLLM_PORT}"
+VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507"
+
+
+def _vllm_is_running() -> bool:
+    """Check if the vLLM server is reachable."""
+    try:
+        r = requests.get(f"{VLLM_BASE_URL}/health", timeout=3)
+        return r.status_code == 200
+    except Exception:
+        return False
+
+
+# Skip all tests in this module if vLLM is not running
+pytestmark = pytest.mark.skipif(
+    not _vllm_is_running(),
+    reason=(
+        f"vLLM server not reachable at {VLLM_BASE_URL}. "
+        "Start it with: python -m example_trainer.vllm_api_server "
+        f"--model {VLLM_MODEL} --port {VLLM_PORT} "
+        "--gpu-memory-utilization 0.8 --max-model-len=32000"
+    ),
+)
+
+
+# =========================================================================
+# Server setup
+# =========================================================================
+
+def _make_server_manager():
+    """Create a ServerManager pointing to the local vLLM server."""
+    from atroposlib.envs.server_handling.server_manager import (
+        ServerManager,
+        APIServerConfig,
+    )
+
+    config = APIServerConfig(
+        base_url=VLLM_BASE_URL,
+        model_name=VLLM_MODEL,
+        server_type="vllm",
+        health_check=False,
+    )
+    sm = ServerManager([config], tool_parser="hermes")
+    sm.servers[0].server_healthy = True
+    return sm
+
+
+def _get_tokenizer():
+    """Load the tokenizer for the model."""
+    from transformers import AutoTokenizer
+    return AutoTokenizer.from_pretrained(VLLM_MODEL)
+
+
+# =========================================================================
+# Fake tools
+# =========================================================================
+
+WEATHER_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "get_weather",
+        "description": "Get the current weather for a city. Returns temperature and conditions.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "city": {
+                    "type": "string",
+                    "description": "City name, e.g. 'Tokyo'",
+                }
+            },
+            "required": ["city"],
+        },
+    },
+}
+
+CALC_TOOL = {
+    "type": "function",
+    "function": {
+        "name": "calculate",
+        "description": "Calculate a math expression. Returns the numeric result.",
+        "parameters": {
+            "type": "object",
+            "properties": {
+                "expression": {
+                    "type": "string",
+                    "description": "Math expression, e.g. '2 + 3'",
+                }
+            },
+            "required": ["expression"],
+        },
+    },
+}
+
+
+def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str:
+    """Handle fake tool calls for testing."""
+    if tool_name == "get_weather":
+        city = args.get("city", "Unknown")
+        return json.dumps({
+            "city": city,
+            "temperature": 22,
+            "conditions": "sunny",
+            "humidity": 45,
+        })
+    elif tool_name == "calculate":
+        expr = args.get("expression", "0")
+        try:
+            result = eval(expr, {"__builtins__": {}}, {})
+            return json.dumps({"result": result})
+        except Exception as e:
+            return json.dumps({"error": str(e)})
+    return json.dumps({"error": f"Unknown tool: {tool_name}"})
+
+
+# =========================================================================
+# Tests
+# =========================================================================
+
+@pytest.mark.asyncio
+async def test_vllm_single_tool_call():
+    """vLLM model calls a tool, gets result, responds — full Phase 2 flow."""
+    sm = _make_server_manager()
+    tokenizer = _get_tokenizer()
+
+    async with sm.managed_server(tokenizer=tokenizer) as managed:
+        agent = HermesAgentLoop(
+            server=managed,
+            tool_schemas=[WEATHER_TOOL],
+            valid_tool_names={"get_weather"},
+            max_turns=5,
+            temperature=0.6,
+            max_tokens=1000,
+        )
+
+        messages = [
+            {"role": "user", "content": "What's the weather in Tokyo? Use the get_weather tool."},
+        ]
+
+        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
+            result = await agent.run(messages)
+
+        assert isinstance(result, AgentResult)
+        assert result.turns_used >= 2, f"Expected at least 2 turns, got {result.turns_used}"
+
+        # Verify tool call happened
+        tool_calls_found = False
+        for msg in result.messages:
+            if msg.get("role") == "assistant" and msg.get("tool_calls"):
+                for tc in msg["tool_calls"]:
+                    if tc["function"]["name"] == "get_weather":
+                        tool_calls_found = True
+                        args = json.loads(tc["function"]["arguments"])
+                        assert "city" in args
+        assert tool_calls_found, "Model should have called get_weather"
+
+        # Verify tool results in conversation
+        tool_results = [m for m in result.messages if m.get("role") == "tool"]
+        assert len(tool_results) >= 1
+
+
+@pytest.mark.asyncio
+async def test_vllm_multi_tool_calls():
+    """vLLM model calls multiple tools across turns."""
+    sm = _make_server_manager()
+    tokenizer = _get_tokenizer()
+
+    async with sm.managed_server(tokenizer=tokenizer) as managed:
+        agent = HermesAgentLoop(
+            server=managed,
+            tool_schemas=[WEATHER_TOOL, CALC_TOOL],
+            valid_tool_names={"get_weather", "calculate"},
+            max_turns=10,
+            temperature=0.6,
+            max_tokens=1000,
+        )
+
+        messages = [
+            {"role": "user", "content": (
+                "I need two things: "
+                "1) What's the weather in Paris? Use get_weather. "
+                "2) What is 15 * 7? Use calculate."
+            )},
+        ]
+
+        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
+            result = await agent.run(messages)
+
+        # Both tools should be called
+        tools_called = set()
+        for msg in result.messages:
+            if msg.get("role") == "assistant" and msg.get("tool_calls"):
+                for tc in msg["tool_calls"]:
+                    tools_called.add(tc["function"]["name"])
+
+        assert "get_weather" in tools_called, f"get_weather not called. Called: {tools_called}"
+        assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}"
+
+
+@pytest.mark.asyncio
+async def test_vllm_managed_server_produces_nodes():
+    """ManagedServer should produce SequenceNodes with tokens and logprobs."""
+    sm = _make_server_manager()
+    tokenizer = _get_tokenizer()
+
+    async with sm.managed_server(tokenizer=tokenizer) as managed:
+        agent = HermesAgentLoop(
+            server=managed,
+            tool_schemas=[WEATHER_TOOL],
+            valid_tool_names={"get_weather"},
+            max_turns=5,
+            temperature=0.6,
+            max_tokens=1000,
+        )
+
+        messages = [
+            {"role": "user", "content": "What's the weather in Berlin? Use get_weather."},
+        ]
+
+        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
+            result = await agent.run(messages)
+
+        # Get the managed state — should have SequenceNodes
+        state = managed.get_state()
+
+        assert state is not None, "ManagedServer should return state"
+        nodes = state.get("nodes", [])
+        assert len(nodes) >= 1, f"Should have at least 1 node, got {len(nodes)}"
+
+        node = nodes[0]
+        assert hasattr(node, "tokens"), "Node should have tokens"
+        assert hasattr(node, "logprobs"), "Node should have logprobs"
+        assert len(node.tokens) > 0, "Tokens should not be empty"
+        assert len(node.logprobs) > 0, "Logprobs should not be empty"
+        assert len(node.tokens) == len(node.logprobs), (
+            f"Tokens ({len(node.tokens)}) and logprobs ({len(node.logprobs)}) should have same length"
+        )
+
+
+@pytest.mark.asyncio
+async def test_vllm_no_tools_direct_response():
+    """vLLM model should respond directly when no tools are needed."""
+    sm = _make_server_manager()
+    tokenizer = _get_tokenizer()
+
+    async with sm.managed_server(tokenizer=tokenizer) as managed:
+        agent = HermesAgentLoop(
+            server=managed,
+            tool_schemas=[WEATHER_TOOL],
+            valid_tool_names={"get_weather"},
+            max_turns=5,
+            temperature=0.6,
+            max_tokens=500,
+        )
+
+        messages = [
+            {"role": "user", "content": "What is 2 + 2? Answer directly, no tools."},
+        ]
+
+        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
+            result = await agent.run(messages)
+
+        assert result.finished_naturally, "Should finish naturally"
+        assert result.turns_used == 1, f"Should take 1 turn, took {result.turns_used}"
+
+        final = result.messages[-1]
+        assert final["role"] == "assistant"
+        assert final["content"], "Should have content"
+
+
+@pytest.mark.asyncio
+async def test_vllm_thinking_content_extracted():
+    """Qwen3-Thinking model should produce reasoning content."""
+    sm = _make_server_manager()
+    tokenizer = _get_tokenizer()
+
+    async with sm.managed_server(
+        tokenizer=tokenizer,
+        preserve_think_blocks=True,
+    ) as managed:
+        agent = HermesAgentLoop(
+            server=managed,
+            tool_schemas=[CALC_TOOL],
+            valid_tool_names={"calculate"},
+            max_turns=5,
+            temperature=0.6,
+            max_tokens=1000,
+        )
+
+        messages = [
+            {"role": "user", "content": "What is 123 * 456? Use the calculate tool."},
+        ]
+
+        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
+            result = await agent.run(messages)
+
+        # Qwen3-Thinking should generate <think> blocks
+        # Check if any content contains thinking markers
+        has_thinking = False
+        for msg in result.messages:
+            content = msg.get("content", "") or ""
+            if "<think>" in content or "</think>" in content:
+                has_thinking = True
+                break
+
+        # Also check reasoning_per_turn
+        has_reasoning = any(r for r in result.reasoning_per_turn if r)
+
+        # At least one of these should be true for a thinking model
+        assert has_thinking or has_reasoning, (
+            "Qwen3-Thinking should produce <think> blocks or reasoning content"
+        )