#!/usr/bin/env python3
"""
Tests for Tool Calling Test Suite.

Exercises the public helpers imported from ``test_tool_calling_1bit``
(looked up in the sibling ``benchmarks`` directory): the ``ToolCallType``
enum, ``parse_tool_call``, ``validate_ls``, the ``TOOL_CALL_TESTS`` table and
the ``TestResult`` / ``BenchmarkResult`` records.
"""

# Provenance: recovered from commit 3c815664e42e9d4b7b5ef4f7fe9af968ae1d1f83,
# "test: add tests for tool calling test suite" (Refs #101).

import json
import os
import sys

import pytest

# Add benchmarks to path so the benchmark module can be imported directly.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks"))

# ``TestResult`` is imported under an alias on purpose: pytest's default
# ``Test*`` class-name rule would otherwise try to collect the imported class
# from this module's namespace and emit a PytestCollectionWarning, because it
# takes constructor arguments.
from test_tool_calling_1bit import (  # noqa: E402
    ToolCallType,
    ToolCallTest,
    TestResult as ToolTestResult,
    BenchmarkResult,
    parse_tool_call,
    validate_ls,
    TOOL_CALL_TESTS,
)


class TestToolCallType:
    """Checks on the string values carried by ``ToolCallType`` members."""

    def test_values(self):
        """Each checked member exposes its expected ``.value`` string."""
        assert ToolCallType.FILE_READ.value == "file_read"
        assert ToolCallType.TERMINAL_EXEC.value == "terminal_exec"
        assert ToolCallType.WEB_SEARCH.value == "web_search"


class TestParseToolCall:
    """Checks that ``parse_tool_call`` returns a ``(tool, params)`` pair."""

    def test_json_format(self):
        """A ``{"tool": ..., "params": ...}`` JSON object is parsed."""
        response = '{"tool": "read_file", "params": {"path": "test.txt"}}'
        tool, params = parse_tool_call(response)
        assert tool == "read_file"
        assert params == {"path": "test.txt"}

    def test_json_alt_format(self):
        """The alternative ``name`` / ``arguments`` JSON keys are parsed."""
        response = '{"name": "terminal", "arguments": {"command": "ls"}}'
        tool, params = parse_tool_call(response)
        assert tool == "terminal"
        assert params == {"command": "ls"}

    def test_function_format(self):
        """A function-call style string yields the name and keyword value."""
        response = 'read_file(path="test.txt")'
        tool, params = parse_tool_call(response)
        assert tool == "read_file"
        assert params.get("path") == "test.txt"

    def test_no_tool(self):
        """Plain text without a tool call yields ``(None, None)``."""
        response = "The answer is 4."
        tool, params = parse_tool_call(response)
        assert tool is None
        assert params is None

    def test_embedded_json(self):
        """A JSON tool call preceded by free text is still extracted."""
        response = (
            'I will call the tool: '
            '{"tool": "search", "params": {"query": "test"}}'
        )
        tool, params = parse_tool_call(response)
        assert tool == "search"
        assert params == {"query": "test"}


class TestValidateLs:
    """Checks that ``validate_ls`` accepts ``ls`` commands only."""

    def test_simple_ls(self):
        """A bare ``ls`` command is accepted."""
        assert validate_ls({"command": "ls"}) is True

    def test_ls_flags(self):
        """``ls`` followed by flags is accepted."""
        assert validate_ls({"command": "ls -la"}) is True
        assert validate_ls({"command": "ls -l"}) is True

    def test_not_ls(self):
        """Commands other than ``ls`` are rejected."""
        assert validate_ls({"command": "pwd"}) is False
        assert validate_ls({"command": "cat file.txt"}) is False


class TestToolCallTests:
    """Sanity checks on the ``TOOL_CALL_TESTS`` table."""

    def test_all_tests_valid(self):
        """The table has ten entries, each with a name, type and prompt."""
        assert len(TOOL_CALL_TESTS) == 10
        for test in TOOL_CALL_TESTS:
            assert test.name
            assert isinstance(test.tool_type, ToolCallType)
            assert test.prompt

    def test_difficulty_distribution(self):
        """Every difficulty level appears at least once."""
        difficulties = [t.difficulty for t in TOOL_CALL_TESTS]
        assert "easy" in difficulties
        assert "medium" in difficulties
        assert "hard" in difficulties

    def test_type_distribution(self):
        """File-read, terminal and web-search cases are all present."""
        types = [t.tool_type for t in TOOL_CALL_TESTS]
        assert ToolCallType.FILE_READ in types
        assert ToolCallType.TERMINAL_EXEC in types
        assert ToolCallType.WEB_SEARCH in types


class TestTestResult:
    """Checks on the per-test result record (``TestResult`` upstream)."""

    def test_creation(self):
        """Positional construction populates the checked attributes."""
        result = ToolTestResult("test1", "file_read", True, 100.0, "response")
        assert result.test_name == "test1"
        assert result.passed is True
        assert result.latency_ms == 100.0

    def test_to_dict(self):
        """The instance attribute mapping exposes the expected keys."""
        result = ToolTestResult("test1", "file_read", True, 100.0, "response")
        # vars() returns the same mapping as result.__dict__.
        d = vars(result)
        assert "test_name" in d
        assert "passed" in d


class TestBenchmarkResult:
    """Checks on the ``BenchmarkResult`` container."""

    def test_creation(self):
        """A new benchmark result starts with an empty ``results`` list."""
        result = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        assert result.model == "model1"
        assert result.results == []

    def test_summary(self):
        """Assigned per-test results are stored on the benchmark result."""
        result = BenchmarkResult("model1", "ollama", "2026-01-01T00:00:00Z")
        result.results = [
            ToolTestResult("t1", "file_read", True, 100, "r1"),
            ToolTestResult("t2", "terminal", False, 200, "r2"),
        ]
        # Summary would be computed by run_all_tests; only storage is
        # checked here.
        assert len(result.results) == 2