diff --git a/run_agent.py b/run_agent.py
index e696ded02..29f01120f 100644
--- a/run_agent.py
+++ b/run_agent.py
@@ -5582,6 +5582,12 @@ class AIAgent:
         invalid_json_args = []
         for tc in assistant_message.tool_calls:
             args = tc.function.arguments
+            if isinstance(args, (dict, list)):
+                tc.function.arguments = json.dumps(args)
+                continue
+            if args is not None and not isinstance(args, str):
+                tc.function.arguments = str(args)
+                args = tc.function.arguments
             # Treat empty/whitespace strings as empty object
             if not args or not args.strip():
                 tc.function.arguments = "{}"
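The run_agent.py hunk normalizes tc.function.arguments before the existing
JSON validation runs: dict and list payloads are re-serialized with
json.dumps, any other non-string value falls back to str(), and the
pre-existing check still maps empty or whitespace-only strings to "{}".
A minimal sketch of the intended behavior, shown outside the diff; the
SimpleNamespace stand-in is illustrative, not part of run_agent.py:

    import json
    from types import SimpleNamespace

    fn = SimpleNamespace(name="read_file", arguments={"path": "README.md"})
    args = fn.arguments
    if isinstance(args, (dict, list)):
        fn.arguments = json.dumps(args)  # dict from a lenient backend -> JSON string
    elif args is not None and not isinstance(args, str):
        fn.arguments = str(args)  # e.g. the int 42 -> "42"
    assert fn.arguments == '{"path": "README.md"}'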
" - "Start it with: python -m example_trainer.vllm_api_server " - f"--model {VLLM_MODEL} --port {VLLM_PORT} " - "--gpu-memory-utilization 0.8 --max-model-len=32000" - ), -) - - -# ========================================================================= -# Server setup -# ========================================================================= - -def _make_server_manager(): - """Create a ServerManager pointing to the local vLLM server.""" - from atroposlib.envs.server_handling.server_manager import ( - ServerManager, - APIServerConfig, - ) - - config = APIServerConfig( - base_url=VLLM_BASE_URL, - model_name=VLLM_MODEL, - server_type="vllm", - health_check=False, - ) - sm = ServerManager([config], tool_parser="hermes") - sm.servers[0].server_healthy = True - return sm - - -def _get_tokenizer(): - """Load the tokenizer for the model.""" - from transformers import AutoTokenizer - return AutoTokenizer.from_pretrained(VLLM_MODEL) - - -# ========================================================================= -# Fake tools -# ========================================================================= - -WEATHER_TOOL = { - "type": "function", - "function": { - "name": "get_weather", - "description": "Get the current weather for a city. Returns temperature and conditions.", - "parameters": { - "type": "object", - "properties": { - "city": { - "type": "string", - "description": "City name, e.g. 'Tokyo'", - } - }, - "required": ["city"], - }, - }, -} - -CALC_TOOL = { - "type": "function", - "function": { - "name": "calculate", - "description": "Calculate a math expression. Returns the numeric result.", - "parameters": { - "type": "object", - "properties": { - "expression": { - "type": "string", - "description": "Math expression, e.g. '2 + 3'", - } - }, - "required": ["expression"], - }, - }, -} - - -def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str: - """Handle fake tool calls for testing.""" - if tool_name == "get_weather": - city = args.get("city", "Unknown") - return json.dumps({ - "city": city, - "temperature": 22, - "conditions": "sunny", - "humidity": 45, - }) - elif tool_name == "calculate": - expr = args.get("expression", "0") - try: - result = eval(expr, {"__builtins__": {}}, {}) - return json.dumps({"result": result}) - except Exception as e: - return json.dumps({"error": str(e)}) - return json.dumps({"error": f"Unknown tool: {tool_name}"}) - - -# ========================================================================= -# Tests -# ========================================================================= - -@pytest.mark.asyncio -async def test_vllm_single_tool_call(): - """vLLM model calls a tool, gets result, responds — full Phase 2 flow.""" - sm = _make_server_manager() - tokenizer = _get_tokenizer() - - async with sm.managed_server(tokenizer=tokenizer) as managed: - agent = HermesAgentLoop( - server=managed, - tool_schemas=[WEATHER_TOOL], - valid_tool_names={"get_weather"}, - max_turns=5, - temperature=0.6, - max_tokens=1000, - ) - - messages = [ - {"role": "user", "content": "What's the weather in Tokyo? 
-
-
-# =========================================================================
-# Server setup
-# =========================================================================
-
-def _make_server_manager():
-    """Create a ServerManager pointing to the local vLLM server."""
-    from atroposlib.envs.server_handling.server_manager import (
-        ServerManager,
-        APIServerConfig,
-    )
-
-    config = APIServerConfig(
-        base_url=VLLM_BASE_URL,
-        model_name=VLLM_MODEL,
-        server_type="vllm",
-        health_check=False,
-    )
-    sm = ServerManager([config], tool_parser="hermes")
-    sm.servers[0].server_healthy = True
-    return sm
-
-
-def _get_tokenizer():
-    """Load the tokenizer for the model."""
-    from transformers import AutoTokenizer
-    return AutoTokenizer.from_pretrained(VLLM_MODEL)
-
-
-# =========================================================================
-# Fake tools
-# =========================================================================
-
-WEATHER_TOOL = {
-    "type": "function",
-    "function": {
-        "name": "get_weather",
-        "description": "Get the current weather for a city. Returns temperature and conditions.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "city": {
-                    "type": "string",
-                    "description": "City name, e.g. 'Tokyo'",
-                }
-            },
-            "required": ["city"],
-        },
-    },
-}
-
-CALC_TOOL = {
-    "type": "function",
-    "function": {
-        "name": "calculate",
-        "description": "Calculate a math expression. Returns the numeric result.",
-        "parameters": {
-            "type": "object",
-            "properties": {
-                "expression": {
-                    "type": "string",
-                    "description": "Math expression, e.g. '2 + 3'",
-                }
-            },
-            "required": ["expression"],
-        },
-    },
-}
-
-
-def _fake_tool_handler(tool_name: str, args: Dict[str, Any], **kwargs) -> str:
-    """Handle fake tool calls for testing."""
-    if tool_name == "get_weather":
-        city = args.get("city", "Unknown")
-        return json.dumps({
-            "city": city,
-            "temperature": 22,
-            "conditions": "sunny",
-            "humidity": 45,
-        })
-    elif tool_name == "calculate":
-        expr = args.get("expression", "0")
-        try:
-            result = eval(expr, {"__builtins__": {}}, {})
-            return json.dumps({"result": result})
-        except Exception as e:
-            return json.dumps({"error": str(e)})
-    return json.dumps({"error": f"Unknown tool: {tool_name}"})
Use get_weather."}, - ] - - with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler): - result = await agent.run(messages) - - # Get the managed state — should have SequenceNodes - state = managed.get_state() - - assert state is not None, "ManagedServer should return state" - nodes = state.get("nodes", []) - assert len(nodes) >= 1, f"Should have at least 1 node, got {len(nodes)}" - - node = nodes[0] - assert hasattr(node, "tokens"), "Node should have tokens" - assert hasattr(node, "logprobs"), "Node should have logprobs" - assert len(node.tokens) > 0, "Tokens should not be empty" - assert len(node.logprobs) > 0, "Logprobs should not be empty" - assert len(node.tokens) == len(node.logprobs), ( - f"Tokens ({len(node.tokens)}) and logprobs ({len(node.logprobs)}) should have same length" +def _tool_call(name: str, arguments): + return SimpleNamespace( + id="call_1", + type="function", + function=SimpleNamespace(name=name, arguments=arguments) ) -@pytest.mark.asyncio -async def test_vllm_no_tools_direct_response(): - """vLLM model should respond directly when no tools are needed.""" - sm = _make_server_manager() - tokenizer = _get_tokenizer() - - async with sm.managed_server(tokenizer=tokenizer) as managed: - agent = HermesAgentLoop( - server=managed, - tool_schemas=[WEATHER_TOOL], - valid_tool_names={"get_weather"}, - max_turns=5, - temperature=0.6, - max_tokens=500, - ) - - messages = [ - {"role": "user", "content": "What is 2 + 2? Answer directly, no tools."}, - ] - - with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler): - result = await agent.run(messages) - - assert result.finished_naturally, "Should finish naturally" - assert result.turns_used == 1, f"Should take 1 turn, took {result.turns_used}" - - final = result.messages[-1] - assert final["role"] == "assistant" - assert final["content"], "Should have content" - - -@pytest.mark.asyncio -async def test_vllm_thinking_content_extracted(): - """Qwen3-Thinking model should produce reasoning content.""" - sm = _make_server_manager() - tokenizer = _get_tokenizer() - - async with sm.managed_server( - tokenizer=tokenizer, - preserve_think_blocks=True, - ) as managed: - agent = HermesAgentLoop( - server=managed, - tool_schemas=[CALC_TOOL], - valid_tool_names={"calculate"}, - max_turns=5, - temperature=0.6, - max_tokens=1000, - ) - - messages = [ - {"role": "user", "content": "What is 123 * 456? 
 
 
-@pytest.mark.asyncio
-async def test_vllm_no_tools_direct_response():
-    """vLLM model should respond directly when no tools are needed."""
-    sm = _make_server_manager()
-    tokenizer = _get_tokenizer()
-
-    async with sm.managed_server(tokenizer=tokenizer) as managed:
-        agent = HermesAgentLoop(
-            server=managed,
-            tool_schemas=[WEATHER_TOOL],
-            valid_tool_names={"get_weather"},
-            max_turns=5,
-            temperature=0.6,
-            max_tokens=500,
-        )
-
-        messages = [
-            {"role": "user", "content": "What is 2 + 2? Answer directly, no tools."},
-        ]
-
-        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
-            result = await agent.run(messages)
-
-    assert result.finished_naturally, "Should finish naturally"
-    assert result.turns_used == 1, f"Should take 1 turn, took {result.turns_used}"
-
-    final = result.messages[-1]
-    assert final["role"] == "assistant"
-    assert final["content"], "Should have content"
-
-
-@pytest.mark.asyncio
-async def test_vllm_thinking_content_extracted():
-    """Qwen3-Thinking model should produce reasoning content."""
-    sm = _make_server_manager()
-    tokenizer = _get_tokenizer()
-
-    async with sm.managed_server(
-        tokenizer=tokenizer,
-        preserve_think_blocks=True,
-    ) as managed:
-        agent = HermesAgentLoop(
-            server=managed,
-            tool_schemas=[CALC_TOOL],
-            valid_tool_names={"calculate"},
-            max_turns=5,
-            temperature=0.6,
-            max_tokens=1000,
-        )
-
-        messages = [
-            {"role": "user", "content": "What is 123 * 456? Use the calculate tool."},
-        ]
-
-        with patch("environments.agent_loop.handle_function_call", side_effect=_fake_tool_handler):
-            result = await agent.run(messages)
-
-    # Qwen3-Thinking should generate <think> blocks
-    # Check if any content contains thinking markers
-    has_thinking = False
-    for msg in result.messages:
-        content = msg.get("content", "") or ""
-        if "<think>" in content or "</think>" in content:
-            has_thinking = True
-            break
-
-    # Also check reasoning_per_turn
-    has_reasoning = any(r for r in result.reasoning_per_turn if r)
-
-    # At least one of these should be true for a thinking model
-    assert has_thinking or has_reasoning, (
-        "Qwen3-Thinking should produce <think> blocks or reasoning content"
+def _response_with_tool_call(arguments):
+    assistant = SimpleNamespace(
+        content=None,
+        reasoning=None,
+        tool_calls=[_tool_call("read_file", arguments)],
     )
+    choice = SimpleNamespace(message=assistant, finish_reason="tool_calls")
+    return SimpleNamespace(choices=[choice], usage=None)
+
+
+class _FakeChatCompletions:
+    def __init__(self):
+        self.calls = 0
+
+    def create(self, **kwargs):
+        self.calls += 1
+        if self.calls == 1:
+            return _response_with_tool_call({"path": "README.md"})
+        return SimpleNamespace(
+            choices=[SimpleNamespace(message=SimpleNamespace(content="done", reasoning=None, tool_calls=[]), finish_reason="stop")],
+            usage=None,
+        )
+
+
+class _FakeClient:
+    def __init__(self):
+        self.chat = SimpleNamespace(completions=_FakeChatCompletions())
+
+
+def test_tool_call_validation_accepts_dict_arguments(monkeypatch):
+    from run_agent import AIAgent
+
+    monkeypatch.setattr("run_agent.OpenAI", lambda **kwargs: _FakeClient())
+    monkeypatch.setattr("run_agent.get_tool_definitions", lambda *args, **kwargs: [{"function": {"name": "read_file"}}])
+    monkeypatch.setattr(
+        "run_agent.handle_function_call",
+        lambda name, args, task_id=None, **kwargs: json.dumps({"ok": True, "args": args}),
+    )
+
+    agent = AIAgent(
+        model="test-model",
+        api_key="test-key",
+        base_url="http://localhost:8080/v1",
+        platform="cli",
+        max_iterations=3,
+        quiet_mode=True,
+        skip_memory=True,
+    )
+
+    result = agent.run_conversation("read the file")
+
+    assert result["final_response"] == "done"
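The replacement test exercises the new normalization path without a live
vLLM server: a fake OpenAI client assembled from SimpleNamespace objects
returns, on the first completion, a tool call whose arguments field is a
plain dict rather than a JSON string, then a final "done" message on the
second. Assuming a standard pytest setup, the test can be run on its own
with:

    pytest tests/test_agent_loop_vllm.py::test_tool_call_validation_accepts_dict_arguments -v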