test: add tests for tool calling test suite
All checks were successful
Smoke Test / smoke (pull_request) Successful in 19s
All checks were successful
Smoke Test / smoke (pull_request) Successful in 19s
Refs #101
This commit is contained in:
125
tests/test_tool_calling_suite.py
Normal file
125
tests/test_tool_calling_suite.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
#!/usr/bin/env python3
"""
Tests for Tool Calling Test Suite.
"""

import json
import pytest

# Add benchmarks to path
# NOTE(review): sys/os are imported mid-file (after json/pytest) solely to
# build the path below; the benchmarks directory is assumed to be a sibling
# of this tests directory — confirm against the repo layout.
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))

# Imported from benchmarks/test_tool_calling_1bit.py via the path inserted above.
from test_tool_calling_1bit import (
    ToolCallType,
    ToolCallTest,
    TestResult,
    BenchmarkResult,
    parse_tool_call,
    validate_ls,
    TOOL_CALL_TESTS,
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCallType:
    """Checks on the ToolCallType enum."""

    def test_values(self):
        """Each enum member carries its expected string value."""
        expected = {
            ToolCallType.FILE_READ: "file_read",
            ToolCallType.TERMINAL_EXEC: "terminal_exec",
            ToolCallType.WEB_SEARCH: "web_search",
        }
        for member, value in expected.items():
            assert member.value == value
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseToolCall:
    """Behavioral checks for parse_tool_call across response formats."""

    def test_json_format(self):
        """Plain JSON using the "tool"/"params" key names is parsed."""
        raw = '{"tool": "read_file", "params": {"path": "test.txt"}}'
        name, args = parse_tool_call(raw)
        assert (name, args) == ("read_file", {"path": "test.txt"})

    def test_json_alt_format(self):
        """The alternate "name"/"arguments" key names are also accepted."""
        raw = '{"name": "terminal", "arguments": {"command": "ls"}}'
        name, args = parse_tool_call(raw)
        assert (name, args) == ("terminal", {"command": "ls"})

    def test_function_format(self):
        """Function-call syntax like read_file(path=...) is recognized."""
        name, args = parse_tool_call('read_file(path="test.txt")')
        assert name == "read_file"
        assert args.get("path") == "test.txt"

    def test_no_tool(self):
        """A plain-text answer with no tool call yields (None, None)."""
        name, args = parse_tool_call("The answer is 4.")
        assert name is None
        assert args is None

    def test_embedded_json(self):
        """A tool-call JSON object embedded in prose is extracted."""
        raw = 'I will call the tool: {"tool": "search", "params": {"query": "test"}}'
        name, args = parse_tool_call(raw)
        assert (name, args) == ("search", {"query": "test"})
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateLs:
    """Checks for the validate_ls command validator."""

    def test_simple_ls(self):
        """A bare `ls` passes validation."""
        assert validate_ls({"command": "ls"}) is True

    def test_ls_flags(self):
        """`ls` with common flags still passes."""
        for cmd in ("ls -la", "ls -l"):
            assert validate_ls({"command": cmd}) is True

    def test_not_ls(self):
        """Commands other than `ls` are rejected."""
        for cmd in ("pwd", "cat file.txt"):
            assert validate_ls({"command": cmd}) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolCallTests:
    """Sanity checks over the TOOL_CALL_TESTS fixture list."""

    def test_all_tests_valid(self):
        """There are exactly 10 tests, each with a name, type, and prompt."""
        assert len(TOOL_CALL_TESTS) == 10
        for case in TOOL_CALL_TESTS:
            assert case.name
            assert isinstance(case.tool_type, ToolCallType)
            assert case.prompt

    def test_difficulty_distribution(self):
        """All three difficulty tiers are represented in the suite."""
        seen = {case.difficulty for case in TOOL_CALL_TESTS}
        assert {"easy", "medium", "hard"} <= seen

    def test_type_distribution(self):
        """All three tool-call types are represented in the suite."""
        seen = {case.tool_type for case in TOOL_CALL_TESTS}
        required = {
            ToolCallType.FILE_READ,
            ToolCallType.TERMINAL_EXEC,
            ToolCallType.WEB_SEARCH,
        }
        assert required <= seen
|
||||||
|
|
||||||
|
|
||||||
|
class TestTestResult:
    """Checks on the TestResult record type."""

    def test_creation(self):
        """Constructor stores name, pass flag, and latency as given."""
        result = TestResult("test1", "file_read", True, 100.0, "response")
        assert result.test_name == "test1"
        assert result.passed is True
        assert result.latency_ms == 100.0

    def test_to_dict(self):
        """The instance's attribute mapping exposes the key fields."""
        result = TestResult("test1", "file_read", True, 100.0, "response")
        # vars() is the idiomatic spelling of attribute-dict access
        # (identical mapping to result.__dict__, reads as intent).
        d = vars(result)
        assert "test_name" in d
        assert "passed" in d
|
||||||
|
|
||||||
|
|
||||||
|
class TestBenchmarkResult:
    """Checks on the BenchmarkResult container."""

    def test_creation(self):
        """A fresh result records the model and starts with no test results."""
        bench = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        assert bench.model == "model1"
        assert bench.results == []

    def test_summary(self):
        """Attached test results are retained on the container."""
        bench = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        bench.results = [
            TestResult("t1", "file_read", True, 100, "r1"),
            TestResult("t2", "terminal", False, 200, "r2"),
        ]
        # Summary would be computed by run_all_tests
        assert len(bench.results) == 2
|
||||||
Reference in New Issue
Block a user