diff --git a/tests/test_agent_loop_vllm.py b/tests/test_agent_loop_vllm.py index b6ce5af09..d47478ecb 100644 --- a/tests/test_agent_loop_vllm.py +++ b/tests/test_agent_loop_vllm.py @@ -1,64 +1,359 @@ +"""Integration tests for HermesAgentLoop with a local vLLM server. + +Tests the full Phase 2 flow: ManagedServer + tool calling with a real +vLLM backend, producing actual token IDs and logprobs for RL training. + +Requires a running vLLM server. Start one from the atropos directory: + + python -m example_trainer.vllm_api_server \ + --model Qwen/Qwen3-4B-Thinking-2507 \ + --port 9001 \ + --gpu-memory-utilization 0.8 \ + --max-model-len=32000 + +Tests are automatically skipped if the server is not reachable. + +Run: + pytest tests/test_agent_loop_vllm.py -v + pytest tests/test_agent_loop_vllm.py -v -k "single" +""" + +import asyncio import json -from types import SimpleNamespace +import os +import sys +from pathlib import Path +from typing import Any, Dict +from unittest.mock import patch + +import pytest +import requests + +# Ensure repo root is importable +_repo_root = Path(__file__).resolve().parent.parent +if str(_repo_root) not in sys.path: + sys.path.insert(0, str(_repo_root)) + +try: + from environments.agent_loop import AgentResult, HermesAgentLoop +except ImportError: + pytest.skip("atroposlib not installed", allow_module_level=True) -def _tool_call(name: str, arguments): - return SimpleNamespace( - id="call_1", - type="function", - function=SimpleNamespace(name=name, arguments=arguments) +# ========================================================================= +# Configuration +# ========================================================================= + +VLLM_HOST = "localhost" +VLLM_PORT = 9001 +VLLM_BASE_URL = f"http://{VLLM_HOST}:{VLLM_PORT}" +VLLM_MODEL = "Qwen/Qwen3-4B-Thinking-2507" + + +def _vllm_is_running() -> bool: + """Check if the vLLM server is reachable.""" + try: + r = requests.get(f"{VLLM_BASE_URL}/health", timeout=3) + return 
r.status_code == 200 + except Exception: + return False + + +# Skip all tests in this module if vLLM is not running +pytestmark = pytest.mark.skipif( + not _vllm_is_running(), + reason=( + f"vLLM server not reachable at {VLLM_BASE_URL}. " + "Start it with: python -m example_trainer.vllm_api_server " + f"--model {VLLM_MODEL} --port {VLLM_PORT} " + "--gpu-memory-utilization 0.8 --max-model-len=32000" + ), +) + + +# ========================================================================= +# Server setup +# ========================================================================= + +def _make_server_manager(): + """Create a ServerManager pointing to the local vLLM server.""" + from atroposlib.envs.server_handling.server_manager import ( + ServerManager, + APIServerConfig, ) - -def _response_with_tool_call(arguments): - assistant = SimpleNamespace( - content=None, - reasoning=None, - tool_calls=[_tool_call("read_file", arguments)], + config = APIServerConfig( + base_url=VLLM_BASE_URL, + model_name=VLLM_MODEL, + server_type="vllm", + health_check=False, ) - choice = SimpleNamespace(message=assistant, finish_reason="tool_calls") - return SimpleNamespace(choices=[choice], usage=None) + sm = ServerManager([config], tool_parser="hermes") + sm.servers[0].server_healthy = True + return sm -class _FakeChatCompletions: - def __init__(self): - self.calls = 0 +def _get_tokenizer(): + """Load the tokenizer for the model.""" + from transformers import AutoTokenizer + return AutoTokenizer.from_pretrained(VLLM_MODEL) - def create(self, **kwargs): - self.calls += 1 - if self.calls == 1: - return _response_with_tool_call({"path": "README.md"}) - return SimpleNamespace( - choices=[SimpleNamespace(message=SimpleNamespace(content="done", reasoning=None, tool_calls=[]), finish_reason="stop")], - usage=None, + +# ========================================================================= +# Fake tools +# ========================================================================= + +WEATHER_TOOL 
= { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather for a city. Returns temperature and conditions.", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name, e.g. 'Tokyo'", + } + }, + "required": ["city"], + }, + }, +} + +CALC_TOOL = { + "type": "function", + "function": { + "name": "calculate", + "description": "Calculate a math expression. Returns the numeric result.", + "parameters": { + "type": "object", + "properties": { + "expression": { + "type": "string", + "description": "Math expression, e.g. '2 + 3'", + } + }, + "required": ["expression"], + }, + }, +} + + +def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str: + """Handle fake tool calls for testing.""" + if tool_name == "get_weather": + city = args.get("city", "Unknown") + return json.dumps({ + "city": city, + "temperature": 22, + "conditions": "sunny", + "humidity": 45, + }) + elif tool_name == "calculate": + expr = args.get("expression", "0") + try: + result = eval(expr, {"__builtins__": {}}, {}) + return json.dumps({"result": result}) + except Exception as e: + return json.dumps({"error": str(e)}) + return json.dumps({"error": f"Unknown tool: {tool_name}"}) + + +# ========================================================================= +# Tests +# ========================================================================= + +@pytest.mark.asyncio +async def test_vllm_single_tool_call(): + """vLLM model calls a tool, gets result, responds — full Phase 2 flow.""" + sm = _make_server_manager() + tokenizer = _get_tokenizer() + + async with sm.managed_server(tokenizer=tokenizer) as managed: + agent = HermesAgentLoop( + server=managed, + tool_schemas=[WEATHER_TOOL], + valid_tool_names={"get_weather"}, + max_turns=5, + temperature=0.6, + max_tokens=1000, ) + messages = [ + {"role": "user", "content": "What's the weather in Tokyo? 
Use the get_weather tool."}, + ] -class _FakeClient: - def __init__(self): - self.chat = SimpleNamespace(completions=_FakeChatCompletions()) + with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler): + result = await agent.run(messages) + + assert isinstance(result, AgentResult) + assert result.turns_used >= 2, f"Expected at least 2 turns, got {result.turns_used}" + + # Verify tool call happened + tool_calls_found = False + for msg in result.messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + if tc["function"]["name"] == "get_weather": + tool_calls_found = True + args = json.loads(tc["function"]["arguments"]) + assert "city" in args + assert tool_calls_found, "Model should have called get_weather" + + # Verify tool results in conversation + tool_results = [m for m in result.messages if m.get("role") == "tool"] + assert len(tool_results) >= 1 -def test_tool_call_validation_accepts_dict_arguments(monkeypatch): - from run_agent import AIAgent +@pytest.mark.asyncio +async def test_vllm_multi_tool_calls(): + """vLLM model calls multiple tools across turns.""" + sm = _make_server_manager() + tokenizer = _get_tokenizer() - monkeypatch.setattr("run_agent.OpenAI", lambda **kwargs: _FakeClient()) - monkeypatch.setattr("run_agent.get_tool_definitions", lambda *args, **kwargs: [{"function": {"name": "read_file"}}]) - monkeypatch.setattr( - "run_agent.handle_function_call", - lambda name, args, task_id=None, **kwargs: json.dumps({"ok": True, "args": args}), + async with sm.managed_server(tokenizer=tokenizer) as managed: + agent = HermesAgentLoop( + server=managed, + tool_schemas=[WEATHER_TOOL, CALC_TOOL], + valid_tool_names={"get_weather", "calculate"}, + max_turns=10, + temperature=0.6, + max_tokens=1000, + ) + + messages = [ + {"role": "user", "content": ( + "I need two things: " + "1) What's the weather in Paris? Use get_weather. " + "2) What is 15 * 7? Use calculate." 
+ )}, + ] + + with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler): + result = await agent.run(messages) + + # Both tools should be called + tools_called = set() + for msg in result.messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + for tc in msg["tool_calls"]: + tools_called.add(tc["function"]["name"]) + + assert "get_weather" in tools_called, f"get_weather not called. Called: {tools_called}" + assert "calculate" in tools_called, f"calculate not called. Called: {tools_called}" + + +@pytest.mark.asyncio +async def test_vllm_managed_server_produces_nodes(): + """ManagedServer should produce SequenceNodes with tokens and logprobs.""" + sm = _make_server_manager() + tokenizer = _get_tokenizer() + + async with sm.managed_server(tokenizer=tokenizer) as managed: + agent = HermesAgentLoop( + server=managed, + tool_schemas=[WEATHER_TOOL], + valid_tool_names={"get_weather"}, + max_turns=5, + temperature=0.6, + max_tokens=1000, + ) + + messages = [ + {"role": "user", "content": "What's the weather in Berlin? 
Use get_weather."}, + ] + + with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler): + result = await agent.run(messages) + + # Get the managed state — should have SequenceNodes + state = managed.get_state() + + assert state is not None, "ManagedServer should return state" + nodes = state.get("nodes", []) + assert len(nodes) >= 1, f"Should have at least 1 node, got {len(nodes)}" + + node = nodes[0] + assert hasattr(node, "tokens"), "Node should have tokens" + assert hasattr(node, "logprobs"), "Node should have logprobs" + assert len(node.tokens) > 0, "Tokens should not be empty" + assert len(node.logprobs) > 0, "Logprobs should not be empty" + assert len(node.tokens) == len(node.logprobs), ( + f"Tokens ({len(node.tokens)}) and logprobs ({len(node.logprobs)}) should have same length" ) - agent = AIAgent( - model="test-model", - api_key="test-key", - base_url="http://localhost:8080/v1", - platform="cli", - max_iterations=3, - quiet_mode=True, - skip_memory=True, + +@pytest.mark.asyncio +async def test_vllm_no_tools_direct_response(): + """vLLM model should respond directly when no tools are needed.""" + sm = _make_server_manager() + tokenizer = _get_tokenizer() + + async with sm.managed_server(tokenizer=tokenizer) as managed: + agent = HermesAgentLoop( + server=managed, + tool_schemas=[WEATHER_TOOL], + valid_tool_names={"get_weather"}, + max_turns=5, + temperature=0.6, + max_tokens=500, + ) + + messages = [ + {"role": "user", "content": "What is 2 + 2? 
Answer directly, no tools."}, + ] + + with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler): + result = await agent.run(messages) + + assert result.finished_naturally, "Should finish naturally" + assert result.turns_used == 1, f"Should take 1 turn, took {result.turns_used}" + + final = result.messages[-1] + assert final["role"] == "assistant" + assert final["content"], "Should have content" + + +@pytest.mark.asyncio +async def test_vllm_thinking_content_extracted(): + """Qwen3-Thinking model should produce reasoning content.""" + sm = _make_server_manager() + tokenizer = _get_tokenizer() + + async with sm.managed_server( + tokenizer=tokenizer, + preserve_think_blocks=True, + ) as managed: + agent = HermesAgentLoop( + server=managed, + tool_schemas=[CALC_TOOL], + valid_tool_names={"calculate"}, + max_turns=5, + temperature=0.6, + max_tokens=1000, + ) + + messages = [ + {"role": "user", "content": "What is 123 * 456? Use the calculate tool."}, + ] + + with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler): + result = await agent.run(messages) + + # Qwen3-Thinking should generate <think> blocks + # Check if any content contains thinking markers + has_thinking = False + for msg in result.messages: + content = msg.get("content", "") or "" + if "<think>" in content or "</think>" in content: + has_thinking = True + break + + # Also check reasoning_per_turn + has_reasoning = any(r for r in result.reasoning_per_turn if r) + + # At least one of these should be true for a thinking model + assert has_thinking or has_reasoning, ( + "Qwen3-Thinking should produce <think> blocks or reasoning content" ) - - result = agent.run_conversation("read the file") - - assert result["final_response"] == "done" diff --git a/tests/test_dict_tool_call_args.py b/tests/test_dict_tool_call_args.py new file mode 100644 index 000000000..e8b4d70fa --- /dev/null +++ b/tests/test_dict_tool_call_args.py @@ -0,0 +1,72 @@ +import json +from types import 
SimpleNamespace + + +def _tool_call(name: str, arguments): + return SimpleNamespace( + id="call_1", + type="function", + function=SimpleNamespace(name=name, arguments=arguments), + ) + + +def _response_with_tool_call(arguments): + assistant = SimpleNamespace( + content=None, + reasoning=None, + tool_calls=[_tool_call("read_file", arguments)], + ) + choice = SimpleNamespace(message=assistant, finish_reason="tool_calls") + return SimpleNamespace(choices=[choice], usage=None) + + +class _FakeChatCompletions: + def __init__(self): + self.calls = 0 + + def create(self, **kwargs): + self.calls += 1 + if self.calls == 1: + return _response_with_tool_call({"path": "README.md"}) + return SimpleNamespace( + choices=[ + SimpleNamespace( + message=SimpleNamespace(content="done", reasoning=None, tool_calls=[]), + finish_reason="stop", + ) + ], + usage=None, + ) + + +class _FakeClient: + def __init__(self): + self.chat = SimpleNamespace(completions=_FakeChatCompletions()) + + +def test_tool_call_validation_accepts_dict_arguments(monkeypatch): + from run_agent import AIAgent + + monkeypatch.setattr("run_agent.OpenAI", lambda **kwargs: _FakeClient()) + monkeypatch.setattr( + "run_agent.get_tool_definitions", + lambda *args, **kwargs: [{"function": {"name": "read_file"}}], + ) + monkeypatch.setattr( + "run_agent.handle_function_call", + lambda name, args, task_id=None, **kwargs: json.dumps({"ok": True, "args": args}), + ) + + agent = AIAgent( + model="test-model", + api_key="test-key", + base_url="http://localhost:8080/v1", + platform="cli", + max_iterations=3, + quiet_mode=True, + skip_memory=True, + ) + + result = agent.run_conversation("read the file") + + assert result["final_response"] == "done"