"""
|
|
Integration test: turboquant compressed model passes hermes tool calls (issue #82).
|
|
|
|
Validates that a TurboQuant-compressed model can:
|
|
1. Parse hermes tool schemas correctly
|
|
2. Format tool calls in OpenAI-compatible format
|
|
3. Pass through the hermes agent conversation loop
|
|
|
|
Tests are structured as contract tests -- they validate the schema/format
|
|
compatibility without requiring a running model server. The live inference
|
|
test is skipped by default (requires llama-server with TurboQuant model).
|
|
|
|
Usage:
|
|
pytest tests/test_tool_call_integration.py -v
|
|
pytest tests/test_tool_call_integration.py -v -k live # run live test if server available
|
|
"""
|
|
import json
import os
import pathlib
import re
import unittest

import pytest

# Repository root: this file lives one directory below it (tests/).
ROOT = pathlib.Path(__file__).resolve().parents[1]
# Hermes profile under test: Gemma-4 configured with TurboQuant KV compression.
PROFILE_PATH = ROOT / "profiles" / "hermes-profile-gemma4-turboquant.yaml"
# Directory holding the benchmark prompt fixtures (test_prompts.json).
BENCHMARKS_DIR = ROOT / "benchmarks"


class TestHermesProfileSchema(unittest.TestCase):
|
|
"""Validate the hermes profile YAML has required fields for tool calling."""
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
import yaml
|
|
cls.profile = yaml.safe_load(PROFILE_PATH.read_text())
|
|
|
|
def test_profile_has_providers(self):
|
|
assert "providers" in self.profile, "Profile must define providers"
|
|
assert "primary" in self.profile["providers"], "Must have primary provider"
|
|
|
|
def test_primary_provider_has_endpoint(self):
|
|
primary = self.profile["providers"]["primary"]
|
|
assert "endpoint" in primary, "Primary provider must have endpoint"
|
|
assert primary["endpoint"].startswith("http"), "Endpoint must be HTTP(S) URL"
|
|
|
|
def test_primary_provider_has_api_path(self):
|
|
primary = self.profile["providers"]["primary"]
|
|
assert "api_path" in primary, "Primary provider must have api_path"
|
|
assert "/chat/completions" in primary["api_path"], (
|
|
"api_path should be OpenAI-compatible /chat/completions"
|
|
)
|
|
|
|
def test_turboquant_settings_present(self):
|
|
primary = self.profile["providers"]["primary"]
|
|
assert "turboquant" in primary, "Must have turboquant config section"
|
|
tq = primary["turboquant"]
|
|
assert tq.get("enabled") is True, "TurboQuant must be enabled"
|
|
assert tq.get("kv_type") in ("turbo2", "turbo3", "turbo4"), (
|
|
"kv_type must be turbo2, turbo3, or turbo4"
|
|
)
|
|
|
|
def test_context_window_configured(self):
|
|
primary = self.profile["providers"]["primary"]
|
|
assert "context" in primary, "Must have context config"
|
|
ctx = primary["context"]
|
|
assert ctx.get("max_tokens", 0) >= 8192, (
|
|
"max_tokens should be >= 8192 for TurboQuant value proposition"
|
|
)
|
|
|
|
|
|
class TestToolSchemaCompatibility(unittest.TestCase):
    """Verify hermes tool schemas serialize to valid JSON for OpenAI tool_calls."""

    # Representative schemas mirroring the hermes built-in tool set.
    SAMPLE_TOOL_SCHEMAS = [
        {
            "type": "function",
            "function": {
                "name": "read_file",
                "description": "Read a text file with line numbers.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "path": {"type": "string", "description": "File path"},
                        "offset": {"type": "integer", "default": 1},
                        "limit": {"type": "integer", "default": 500},
                    },
                    "required": ["path"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "execute_code",
                "description": "Run a Python script.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "code": {"type": "string", "description": "Python code"},
                    },
                    "required": ["code"],
                },
            },
        },
        {
            "type": "function",
            "function": {
                "name": "web_search",
                "description": "Search the web.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {"type": "string"},
                        "max_results": {"type": "integer", "default": 5},
                    },
                    "required": ["query"],
                },
            },
        },
    ]

    def test_tool_schemas_serialize_to_json(self):
        """Tool schemas must serialize without errors."""
        payload = json.dumps(self.SAMPLE_TOOL_SCHEMAS)
        assert len(payload) > 0
        round_tripped = json.loads(payload)
        assert len(round_tripped) == len(self.SAMPLE_TOOL_SCHEMAS)

    def test_tool_schemas_have_required_openai_fields(self):
        """Each tool schema must have the fields OpenAI expects."""
        for schema in self.SAMPLE_TOOL_SCHEMAS:
            assert schema["type"] == "function", "Tool type must be 'function'"
            function_spec = schema["function"]
            for field in ("name", "description", "parameters"):
                assert field in function_spec, f"Function must have {field}"
            parameters = function_spec["parameters"]
            assert parameters["type"] == "object", "Parameters type must be 'object'"
            assert "properties" in parameters, "Parameters must have properties"

    def test_tool_call_response_format(self):
        """Verify tool_call response matches OpenAI format."""
        sample_call = {
            "id": "call_abc123",
            "type": "function",
            "function": {
                "name": "read_file",
                "arguments": json.dumps({"path": "/tmp/test.txt"}),
            },
        }
        decoded_args = json.loads(sample_call["function"]["arguments"])
        assert decoded_args["path"] == "/tmp/test.txt"
        known_names = {t["function"]["name"] for t in self.SAMPLE_TOOL_SCHEMAS}
        assert sample_call["function"]["name"] in known_names

    def test_tool_names_are_valid_identifiers(self):
        """Tool names must be valid Python identifiers for hermes dispatch."""
        identifier_pattern = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
        for schema in self.SAMPLE_TOOL_SCHEMAS:
            tool_name = schema["function"]["name"]
            assert identifier_pattern.match(tool_name), (
                f"Tool name '{tool_name}' is not a valid identifier"
            )


class TestTurboquantServerConfig(unittest.TestCase):
    """Validate server startup configuration matches hermes profile."""

    @staticmethod
    def _profile_text():
        """Read the raw profile YAML (re-read per test to avoid stale state)."""
        return PROFILE_PATH.read_text()

    def test_server_command_has_turboquant_flags(self):
        """The server command in the profile must include -ctk/-ctv flags."""
        text = self._profile_text()
        assert "-ctk" in text, "Profile server command must include -ctk flag"
        assert "-ctv" in text, "Profile server command must include -ctv flag"

    def test_server_command_has_context_flag(self):
        """Server command must set context size."""
        assert re.search(r"-c\s+\d+", self._profile_text()), (
            "Server command must include -c <context_size> flag"
        )

    def test_layer_adaptive_env_var(self):
        """Profile must set TURBO_LAYER_ADAPTIVE env var."""
        assert "TURBO_LAYER_ADAPTIVE" in self._profile_text(), (
            "Profile must configure TURBO_LAYER_ADAPTIVE"
        )


class TestBenchmarkData(unittest.TestCase):
|
|
"""Validate benchmark test prompts include tool-call test cases."""
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
prompts_path = BENCHMARKS_DIR / "test_prompts.json"
|
|
cls.prompts = json.loads(prompts_path.read_text())
|
|
|
|
def test_has_tool_call_test_prompt(self):
|
|
"""Benchmark prompts must include a tool-call format test."""
|
|
categories = [p.get("category") for p in self.prompts]
|
|
assert "tool_call_format" in categories, (
|
|
"Benchmark must include a tool_call_format test case"
|
|
)
|
|
|
|
def test_tool_call_prompt_expects_json(self):
|
|
"""Tool call test prompt must expect JSON in the response."""
|
|
tool_prompt = next(
|
|
p for p in self.prompts if p.get("category") == "tool_call_format"
|
|
)
|
|
pattern = tool_prompt.get("expected_pattern", "")
|
|
assert "json" in pattern.lower() or "\\{" in pattern, (
|
|
"Tool call prompt must expect JSON-formatted response"
|
|
)
|
|
|
|
|
|
@pytest.mark.skipif(
    not os.environ.get("TURBOQUANT_SERVER_URL"),
    reason="No TurboQuant server available (set TURBOQUANT_SERVER_URL to run)",
)
class TestLiveToolCallIntegration:
    """Live integration test -- requires running llama-server with TurboQuant.

    The two tool-call tests previously duplicated both the function-tool
    schema literals and the chat-completions request payload; both are now
    built through the `_function_tool` / `_chat_completion` helpers so the
    request shape is defined in exactly one place.
    """

    @staticmethod
    def _function_tool(name, description, properties, required):
        """Build one OpenAI-format function-tool schema.

        :param name: tool/function name.
        :param description: human-readable description.
        :param properties: JSON-schema ``properties`` mapping.
        :param required: list of required parameter names.
        """
        return {
            "type": "function",
            "function": {
                "name": name,
                "description": description,
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        }

    @staticmethod
    def _chat_completion(url, messages, tools):
        """POST a tool-enabled chat completion and return the HTTP response."""
        # Local import: requests is only needed when the live tests run.
        import requests

        return requests.post(
            f"{url}/v1/chat/completions",
            json={
                "model": "gemma-4",
                "messages": messages,
                "tools": tools,
                "tool_choice": "auto",
            },
            timeout=120,
        )

    def test_server_health(self):
        """Server must respond to /v1/models endpoint."""
        import requests

        url = os.environ["TURBOQUANT_SERVER_URL"]
        resp = requests.get(f"{url}/v1/models", timeout=10)
        assert resp.status_code == 200
        data = resp.json()
        assert "data" in data
        assert len(data["data"]) > 0

    def test_tool_call_completion(self):
        """Model must return a valid tool_call for a read_file prompt."""
        url = os.environ["TURBOQUANT_SERVER_URL"]
        tools = [
            self._function_tool(
                "read_file",
                "Read a file",
                {"path": {"type": "string"}},
                ["path"],
            )
        ]
        resp = self._chat_completion(
            url,
            [{"role": "user", "content": "Read the file at /tmp/test.txt"}],
            tools,
        )
        assert resp.status_code == 200
        data = resp.json()
        choice = data["choices"][0]
        msg = choice["message"]
        if "tool_calls" in msg and msg["tool_calls"]:
            tc = msg["tool_calls"][0]
            assert tc["type"] == "function"
            assert tc["function"]["name"] == "read_file"
            args = json.loads(tc["function"]["arguments"])
            assert "path" in args
        else:
            # With tool_choice="auto" the model may answer in plain text;
            # accept any non-empty content rather than failing the live run.
            assert len(msg.get("content", "")) > 0

    def test_tool_call_with_multiple_tools(self):
        """Model must handle multiple available tools."""
        url = os.environ["TURBOQUANT_SERVER_URL"]
        tools = [
            self._function_tool(
                "read_file",
                "Read a file",
                {"path": {"type": "string"}},
                ["path"],
            ),
            self._function_tool(
                "web_search",
                "Search the web",
                {"query": {"type": "string"}},
                ["query"],
            ),
            self._function_tool(
                "execute_code",
                "Run Python code",
                {"code": {"type": "string"}},
                ["code"],
            ),
        ]
        resp = self._chat_completion(
            url,
            [{"role": "user", "content": "Search the web for 'bitcoin price'"}],
            tools,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert "choices" in data
        assert len(data["choices"]) > 0


if __name__ == "__main__":
    # Allow running this file directly with the stdlib runner; the
    # pytest-marked live class is still only collected under pytest.
    unittest.main()