"""
Integration test: turboquant compressed model passes hermes tool calls (issue #82).

Validates that a TurboQuant-compressed model can:
1. Parse hermes tool schemas correctly
2. Format tool calls in OpenAI-compatible format
3. Pass through the hermes agent conversation loop

Tests are structured as contract tests -- they validate the schema/format
compatibility without requiring a running model server. The live inference
tests are skipped unless the TURBOQUANT_SERVER_URL environment variable points
at a running llama-server with a TurboQuant model.

Usage:
    pytest tests/test_tool_call_integration.py -v

    # Run the live tests against a running server:
    TURBOQUANT_SERVER_URL=http://host:port \
        pytest tests/test_tool_call_integration.py -v -k live
"""
import json
import os
import pathlib
import re
import unittest

import pytest

# This file lives in tests/, so parents[1] is the repository root.
ROOT = pathlib.Path(__file__).resolve().parents[1]
PROFILE_PATH = ROOT / "profiles" / "hermes-profile-gemma4-turboquant.yaml"
BENCHMARKS_DIR = ROOT / "benchmarks"

# Settings shared by the live tests.
_SERVER_URL_ENV = "TURBOQUANT_SERVER_URL"
_LIVE_MODEL = "gemma-4"
_HEALTH_TIMEOUT_S = 10
_COMPLETION_TIMEOUT_S = 120


def _read_profile_text() -> str:
    """Return the raw text of the hermes profile YAML.

    The encoding is pinned to UTF-8 so the result does not depend on the
    locale of the machine running the tests.
    """
    return PROFILE_PATH.read_text(encoding="utf-8")


def _server_url() -> str:
    """Return the live server base URL without a trailing slash.

    Stripping the slash keeps ``f"{url}/v1/..."`` from producing ``//v1/...``
    when the environment variable is set to e.g. ``http://host:8080/``.

    Raises:
        KeyError: if TURBOQUANT_SERVER_URL is not set.
    """
    return os.environ[_SERVER_URL_ENV].rstrip("/")


def _string_arg_tool(name: str, description: str, arg: str) -> dict:
    """Build an OpenAI-style function tool taking one required string argument."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": {arg: {"type": "string"}},
                "required": [arg],
            },
        },
    }


def _post_chat_completion(prompt: str, tools: list):
    """POST a single-user-message chat completion with *tools* to the live server.

    Returns the ``requests`` response object; callers assert on it.
    """
    # Imported lazily so the contract tests can be collected and run on
    # machines where ``requests`` is not installed.
    import requests

    return requests.post(
        f"{_server_url()}/v1/chat/completions",
        json={
            "model": _LIVE_MODEL,
            "messages": [{"role": "user", "content": prompt}],
            "tools": tools,
            "tool_choice": "auto",
        },
        timeout=_COMPLETION_TIMEOUT_S,
    )


class TestHermesProfileSchema(unittest.TestCase):
    """Validate the hermes profile YAML has required fields for tool calling."""

    @classmethod
    def setUpClass(cls):
        # Imported lazily so that collecting this module does not require
        # PyYAML; only this class needs the parsed profile.
        import yaml

        cls.profile = yaml.safe_load(_read_profile_text())

    def _primary(self):
        """Return the ``providers.primary`` mapping of the profile."""
        return self.profile["providers"]["primary"]

    def test_profile_has_providers(self):
        assert "providers" in self.profile, "Profile must define providers"
        assert "primary" in self.profile["providers"], "Must have primary provider"

    def test_primary_provider_has_endpoint(self):
        primary = self._primary()
        assert "endpoint" in primary, "Primary provider must have endpoint"
        assert primary["endpoint"].startswith("http"), "Endpoint must be HTTP(S) URL"

    def test_primary_provider_has_api_path(self):
        primary = self._primary()
        assert "api_path" in primary, "Primary provider must have api_path"
        assert "/chat/completions" in primary["api_path"], (
            "api_path should be OpenAI-compatible /chat/completions"
        )

    def test_turboquant_settings_present(self):
        primary = self._primary()
        assert "turboquant" in primary, "Must have turboquant config section"
        tq = primary["turboquant"]
        assert tq.get("enabled") is True, "TurboQuant must be enabled"
        assert tq.get("kv_type") in ("turbo2", "turbo3", "turbo4"), (
            "kv_type must be turbo2, turbo3, or turbo4"
        )

    def test_context_window_configured(self):
        primary = self._primary()
        assert "context" in primary, "Must have context config"
        ctx = primary["context"]
        assert ctx.get("max_tokens", 0) >= 8192, (
            "max_tokens should be >= 8192 for TurboQuant value proposition"
        )


class TestToolSchemaCompatibility(unittest.TestCase):
    """Verify hermes tool schemas serialize to valid JSON for OpenAI tool_calls."""

    SAMPLE_TOOL_SCHEMAS = [
        {
            "type": "function",
            "function": {
                "name": "read_file",
                "description": "Read a text file with line numbers.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "File path"},
                        "offset": {"type": "integer", "default": 1},
                        "limit": {"type": "integer", "default": 500},
                    },
                    "required": ["path"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "execute_code",
                "description": "Run a Python script.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {"type": "string", "description": "Python code"},
                    },
                    "required": ["code"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "web_search",
                "description": "Search the web.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "max_results": {"type": "integer", "default": 5},
                    },
                    "required": ["query"],
                },
            },
        },
    ]

    def test_tool_schemas_serialize_to_json(self):
        """Tool schemas must serialize without errors."""
        serialized = json.dumps(self.SAMPLE_TOOL_SCHEMAS)
        assert len(serialized) > 0
        parsed = json.loads(serialized)
        assert len(parsed) == len(self.SAMPLE_TOOL_SCHEMAS)

    def test_tool_schemas_have_required_openai_fields(self):
        """Each tool schema must have the fields OpenAI expects."""
        for tool in self.SAMPLE_TOOL_SCHEMAS:
            assert tool["type"] == "function", "Tool type must be 'function'"
            fn = tool["function"]
            assert "name" in fn, "Function must have name"
            assert "description" in fn, "Function must have description"
            assert "parameters" in fn, "Function must have parameters"
            params = fn["parameters"]
            assert params["type"] == "object", "Parameters type must be 'object'"
            assert "properties" in params, "Parameters must have properties"

    def test_tool_call_response_format(self):
        """Verify tool_call response matches OpenAI format."""
        tool_call = {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "read_file",
                # OpenAI encodes arguments as a JSON *string*, not an object.
                "arguments": json.dumps({"path": "/tmp/test.txt"}),
            },
        }
        args = json.loads(tool_call["function"]["arguments"])
        assert args["path"] == "/tmp/test.txt"
        known_names = {t["function"]["name"] for t in self.SAMPLE_TOOL_SCHEMAS}
        assert tool_call["function"]["name"] in known_names

    def test_tool_names_are_valid_identifiers(self):
        """Tool names must be valid Python identifiers for hermes dispatch."""
        for tool in self.SAMPLE_TOOL_SCHEMAS:
            name = tool["function"]["name"]
            # fullmatch rather than match with "^...$": "$" also matches just
            # before a trailing newline, which would let "read_file\n" through.
            assert re.fullmatch(r"[a-zA-Z_][a-zA-Z0-9_]*", name), (
                f"Tool name '{name}' is not a valid identifier"
            )


class TestTurboquantServerConfig(unittest.TestCase):
    """Validate server startup configuration matches hermes profile."""

    def test_server_command_has_turboquant_flags(self):
        """The server command in the profile must include -ctk/-ctv flags."""
        profile_text = _read_profile_text()
        assert "-ctk" in profile_text, "Profile server command must include -ctk flag"
        assert "-ctv" in profile_text, "Profile server command must include -ctv flag"

    def test_server_command_has_context_flag(self):
        """Server command must set context size."""
        profile_text = _read_profile_text()
        assert re.search(r"-c\s+\d+", profile_text), (
            "Server command must include -c flag"
        )

    def test_layer_adaptive_env_var(self):
        """Profile must set TURBO_LAYER_ADAPTIVE env var."""
        profile_text = _read_profile_text()
        assert "TURBO_LAYER_ADAPTIVE" in profile_text, (
            "Profile must configure TURBO_LAYER_ADAPTIVE"
        )


class TestBenchmarkData(unittest.TestCase):
    """Validate benchmark test prompts include tool-call test cases."""

    @classmethod
    def setUpClass(cls):
        prompts_path = BENCHMARKS_DIR / "test_prompts.json"
        cls.prompts = json.loads(prompts_path.read_text(encoding="utf-8"))

    def test_has_tool_call_test_prompt(self):
        """Benchmark prompts must include a tool-call format test."""
        categories = [p.get("category") for p in self.prompts]
        assert "tool_call_format" in categories, (
            "Benchmark must include a tool_call_format test case"
        )

    def test_tool_call_prompt_expects_json(self):
        """Tool call test prompt must expect JSON in the response."""
        # Default of None turns a missing prompt into a readable assertion
        # failure instead of a bare StopIteration error.
        tool_prompt = next(
            (p for p in self.prompts if p.get("category") == "tool_call_format"),
            None,
        )
        assert tool_prompt is not None, (
            "Benchmark must include a tool_call_format test case"
        )
        # "or ''" also covers an explicit null expected_pattern in the JSON.
        pattern = tool_prompt.get("expected_pattern") or ""
        assert "json" in pattern.lower() or "\\{" in pattern, (
            "Tool call prompt must expect JSON-formatted response"
        )


@pytest.mark.skipif(
    not os.environ.get("TURBOQUANT_SERVER_URL"),
    reason="No TurboQuant server available (set TURBOQUANT_SERVER_URL to run)",
)
class TestLiveToolCallIntegration:
    """Live integration test -- requires running llama-server with TurboQuant."""

    def test_server_health(self):
        """Server must respond to /v1/models endpoint."""
        import requests

        resp = requests.get(f"{_server_url()}/v1/models", timeout=_HEALTH_TIMEOUT_S)
        assert resp.status_code == 200
        data = resp.json()
        assert "data" in data
        assert len(data["data"]) > 0

    def test_tool_call_completion(self):
        """A read_file prompt must yield a well-formed tool call or a text reply.

        With ``tool_choice="auto"`` the model may answer in plain text; that is
        accepted as long as the reply is non-empty. When a tool call is
        emitted, it must be a ``read_file`` call with a ``path`` argument.
        """
        tools = [_string_arg_tool("read_file", "Read a file", "path")]
        resp = _post_chat_completion("Read the file at /tmp/test.txt", tools)
        assert resp.status_code == 200
        data = resp.json()
        msg = data["choices"][0]["message"]
        tool_calls = msg.get("tool_calls")
        if tool_calls:
            tc = tool_calls[0]
            assert tc["type"] == "function"
            assert tc["function"]["name"] == "read_file"
            args = json.loads(tc["function"]["arguments"])
            assert "path" in args
        else:
            # The content key may be present with a null value, so fall back
            # to "" before measuring the length.
            assert len(msg.get("content") or "") > 0

    def test_tool_call_with_multiple_tools(self):
        """Model must handle multiple available tools."""
        tools = [
            _string_arg_tool("read_file", "Read a file", "path"),
            _string_arg_tool("web_search", "Search the web", "query"),
            _string_arg_tool("execute_code", "Run Python code", "code"),
        ]
        resp = _post_chat_completion("Search the web for 'bitcoin price'", tools)
        assert resp.status_code == 200
        data = resp.json()
        assert "choices" in data
        assert len(data["choices"]) > 0


if __name__ == "__main__":
    unittest.main()