Compare commits

..

1 Commits

Author SHA1 Message Date
Alexander Whitestone
d1fb50bf2f feat: add Anthropic transport abstraction slice (#951)
All checks were successful
Lint / lint (pull_request) Successful in 8s
- add transport registry, shared transport dataclasses, and AnthropicTransport
- add normalize_anthropic_response_v2 as the bridge from existing Anthropic normalization to shared transport types
- extend Anthropic stop-reason mapping for refusal and model_context_window_exceeded
- add targeted transport and v2 normalization regression tests

Closes #951
Refs #949
2026-04-22 11:20:20 -04:00
11 changed files with 1035 additions and 568 deletions

View File

@@ -1396,6 +1396,8 @@ def normalize_anthropic_response(
"tool_use": "tool_calls",
"max_tokens": "length",
"stop_sequence": "stop",
"refusal": "content_filter",
"model_context_window_exceeded": "length",
}
finish_reason = stop_reason_map.get(response.stop_reason, "stop")
@@ -1409,3 +1411,42 @@ def normalize_anthropic_response(
),
finish_reason,
)
def normalize_anthropic_response_v2(
    response,
    strip_tool_prefix: bool = False,
) -> "NormalizedResponse":
    """Return an Anthropic response as a shared ``NormalizedResponse``.

    Runs the legacy ``normalize_anthropic_response()`` and repackages its
    ``(assistant_message, finish_reason)`` result into the transport types,
    so the legacy function and its call sites stay untouched.

    Args:
        response: Raw Anthropic response, handed to the legacy normalizer.
        strip_tool_prefix: Forwarded unchanged to the legacy normalizer.

    Returns:
        A ``NormalizedResponse``. ``usage`` is always ``None`` here, and
        ``provider_data`` holds ``reasoning_details`` when the legacy message
        carries a truthy value for it (otherwise it is ``None``).
    """
    from agent.transports.types import NormalizedResponse, build_tool_call

    message, finish_reason = normalize_anthropic_response(response, strip_tool_prefix)

    # Re-wrap each legacy tool call in the shared ToolCall type.
    converted_calls = None
    if message.tool_calls:
        converted_calls = []
        for call in message.tool_calls:
            converted_calls.append(
                build_tool_call(
                    id=call.id,
                    name=call.function.name,
                    arguments=call.function.arguments,
                )
            )

    # Only attach provider_data when there is something to carry.
    details = getattr(message, "reasoning_details", None)
    extras = {"reasoning_details": details} if details else None

    return NormalizedResponse(
        content=message.content,
        tool_calls=converted_calls,
        finish_reason=finish_reason,
        reasoning=getattr(message, "reasoning", None),
        usage=None,
        provider_data=extras,
    )

View File

@@ -0,0 +1,57 @@
"""Transport layer types and registry for provider response normalization.
Usage:
from agent.transports import get_transport
transport = get_transport("anthropic_messages")
result = transport.normalize_response(raw_response)
"""
from agent.transports.types import ( # noqa: F401
NormalizedResponse,
ToolCall,
Usage,
build_tool_call,
map_finish_reason,
)
# api_mode string -> transport class, filled in by register_transport().
_REGISTRY: dict = {}

# True once _discover_transports() has been attempted. Tracking this
# separately from "_REGISTRY is empty" means discovery still runs when a
# transport was registered by hand (or one transport module was imported
# directly) before the first get_transport() call.
_DISCOVERY_DONE: bool = False

# Modules imported by _discover_transports(); each is expected to call
# register_transport() at import time.
_TRANSPORT_MODULES = (
    "agent.transports.anthropic",
    "agent.transports.codex",
    "agent.transports.chat_completions",
    "agent.transports.bedrock",
)


def register_transport(api_mode: str, transport_cls: type) -> None:
    """Register a transport class for an api_mode string.

    A later registration for the same ``api_mode`` replaces the earlier one.
    """
    _REGISTRY[api_mode] = transport_cls


def get_transport(api_mode: str):
    """Get a transport instance for the given api_mode.

    Transport modules are discovered lazily on the first lookup. Discovery is
    also retried while the registry is still empty, as before.

    Returns None if no transport is registered for this api_mode.
    This allows gradual migration — call sites can check for None
    and fall back to the legacy code path.
    """
    global _DISCOVERY_DONE
    if not _DISCOVERY_DONE or not _REGISTRY:
        _discover_transports()
        _DISCOVERY_DONE = True
    cls = _REGISTRY.get(api_mode)
    if cls is None:
        return None
    return cls()


def _discover_transports() -> None:
    """Import all transport modules to trigger auto-registration.

    A module that cannot be imported is skipped (best effort), but the
    failure is logged at DEBUG level instead of being silently discarded.
    """
    import importlib
    import logging

    log = logging.getLogger(__name__)
    for module_name in _TRANSPORT_MODULES:
        try:
            importlib.import_module(module_name)
        except ImportError as exc:
            log.debug("Transport module %s not loaded: %s", module_name, exc)

View File

@@ -0,0 +1,95 @@
"""Anthropic Messages API transport.
Delegates to the existing adapter functions in agent/anthropic_adapter.py.
This transport owns format conversion and normalization — NOT client lifecycle.
"""
from typing import Any, Dict, List, Optional
from agent.transports.base import ProviderTransport
from agent.transports.types import NormalizedResponse
class AnthropicTransport(ProviderTransport):
    """Provider transport for ``api_mode='anthropic_messages'``.

    Each conversion step is forwarded to a function in
    ``agent.anthropic_adapter``, which is imported at call time.
    """

    # Anthropic stop_reason -> normalized finish reason.
    _STOP_REASON_MAP = {
        "end_turn": "stop",
        "tool_use": "tool_calls",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "refusal": "content_filter",
        "model_context_window_exceeded": "length",
    }

    @property
    def api_mode(self) -> str:
        """Return the api_mode string this transport is registered under."""
        return "anthropic_messages"

    def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
        """Convert messages via the adapter; only ``base_url`` is read from kwargs."""
        from agent.anthropic_adapter import convert_messages_to_anthropic

        return convert_messages_to_anthropic(messages, base_url=kwargs.get("base_url"))

    def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
        """Convert tool definitions via the adapter."""
        from agent.anthropic_adapter import convert_tools_to_anthropic

        converted = convert_tools_to_anthropic(tools)
        return converted

    def build_kwargs(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **params,
    ) -> Dict[str, Any]:
        """Build the request kwargs through ``build_anthropic_kwargs``.

        Optional settings are read from ``params``; a missing key falls back
        to the default shown in the corresponding lookup below.
        """
        from agent.anthropic_adapter import build_anthropic_kwargs

        option = params.get
        return build_anthropic_kwargs(
            model=model,
            messages=messages,
            tools=tools,
            max_tokens=option("max_tokens", 16384),
            reasoning_config=option("reasoning_config"),
            tool_choice=option("tool_choice"),
            is_oauth=option("is_oauth", False),
            preserve_dots=option("preserve_dots", False),
            context_length=option("context_length"),
            base_url=option("base_url"),
            fast_mode=option("fast_mode", False),
        )

    def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
        """Normalize a raw response; ``strip_tool_prefix`` defaults to False."""
        from agent.anthropic_adapter import normalize_anthropic_response_v2

        return normalize_anthropic_response_v2(
            response,
            strip_tool_prefix=kwargs.get("strip_tool_prefix", False),
        )

    def validate_response(self, response: Any) -> bool:
        """Return True only when ``response.content`` is a non-empty list."""
        # getattr on None yields the default, so a None response is rejected
        # by the isinstance check as well.
        blocks = getattr(response, "content", None)
        return isinstance(blocks, list) and len(blocks) > 0

    def extract_cache_stats(self, response: Any):
        """Return cache read/creation token counts, or None when both are zero or absent."""
        usage = getattr(response, "usage", None)
        if usage is None:
            return None
        read_tokens = getattr(usage, "cache_read_input_tokens", 0) or 0
        write_tokens = getattr(usage, "cache_creation_input_tokens", 0) or 0
        if not (read_tokens or write_tokens):
            return None
        return {"cached_tokens": read_tokens, "creation_tokens": write_tokens}

    def map_finish_reason(self, raw_reason: str) -> str:
        """Map an Anthropic stop reason to the normalized set; unknown values give "stop"."""
        table = self._STOP_REASON_MAP
        return table[raw_reason] if raw_reason in table else "stop"
# Register the transport under its api_mode string as a side effect of
# importing this module.
# NOTE(review): the import sits below the class rather than at the top of the
# file (hence the E402 suppression) — presumably to sidestep an import cycle
# with the agent.transports package; confirm before moving it.
from agent.transports import register_transport # noqa: E402
register_transport("anthropic_messages", AnthropicTransport)

61
agent/transports/base.py Normal file
View File

@@ -0,0 +1,61 @@
"""Abstract base for provider transports.
A transport owns the data path for one api_mode:
convert_messages → convert_tools → build_kwargs → normalize_response
It does NOT own: client construction, streaming, credential refresh,
prompt caching, interrupt handling, or retry logic. Those stay on AIAgent.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from agent.transports.types import NormalizedResponse
class ProviderTransport(ABC):
    """Interface for provider-specific format conversion and normalization.

    Subclasses must implement ``api_mode``, ``convert_messages``,
    ``convert_tools``, ``build_kwargs`` and ``normalize_response``. The
    remaining hooks are optional and ship with permissive defaults.
    """

    @property
    @abstractmethod
    def api_mode(self) -> str:
        """Return the api_mode string handled by this transport."""

    @abstractmethod
    def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
        """Turn OpenAI-format messages into the provider-native format."""

    @abstractmethod
    def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
        """Turn OpenAI-format tool definitions into the provider-native format."""

    @abstractmethod
    def build_kwargs(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        **params,
    ) -> Dict[str, Any]:
        """Assemble the complete provider kwargs dict."""

    @abstractmethod
    def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
        """Convert a raw provider response into the shared NormalizedResponse type."""

    def validate_response(self, response: Any) -> bool:
        """Structural validation hook for raw responses; accepts everything by default."""
        return True

    def extract_cache_stats(self, response: Any) -> Optional[Dict[str, int]]:
        """Cache statistics hook; reports nothing by default."""
        return None

    def map_finish_reason(self, raw_reason: str) -> str:
        """Stop-reason mapping hook; returns the raw reason unchanged by default."""
        return raw_reason

58
agent/transports/types.py Normal file
View File

@@ -0,0 +1,58 @@
"""Shared types for normalized provider responses."""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
@dataclass
class ToolCall:
    """Provider-agnostic record of a single tool invocation."""

    # Call identifier; may be None.
    id: Optional[str]
    # Name of the tool being called.
    name: str
    # Call arguments carried as a string.
    arguments: str
    # Optional provider-specific extras; excluded from repr().
    provider_data: Optional[Dict[str, Any]] = field(repr=False, default=None)
@dataclass
class Usage:
    """Token counts reported for one API response; every counter defaults to 0."""

    prompt_tokens: int = 0      # tokens in the prompt
    completion_tokens: int = 0  # tokens in the completion
    total_tokens: int = 0       # total token count
    cached_tokens: int = 0      # cached token count
@dataclass
class NormalizedResponse:
    """Provider-agnostic shape of an API response."""

    # Text content; may be None.
    content: Optional[str]
    # Normalized tool calls; may be None.
    tool_calls: Optional[List[ToolCall]]
    # Normalized finish reason string.
    finish_reason: str
    # Optional reasoning text.
    reasoning: Optional[str] = None
    # Optional token usage.
    usage: Optional[Usage] = None
    # Optional provider-specific extras; excluded from repr().
    provider_data: Optional[Dict[str, Any]] = field(repr=False, default=None)
def build_tool_call(
    id: Optional[str],
    name: str,
    arguments: Any,
    **provider_fields: Any,
) -> ToolCall:
    """Build a ToolCall, serialising non-string arguments to JSON.

    Args:
        id: Call identifier, or None.
        name: Tool name.
        arguments: A string is stored as-is. ``None`` becomes ``"{}"``. Any
            other value is JSON-encoded, so lists, booleans and the like
            yield valid JSON instead of a Python repr.
        **provider_fields: Extra provider-specific values, stored as
            ``provider_data`` (``None`` when no extras are given).

    Returns:
        The populated ``ToolCall``.

    Raises:
        TypeError, ValueError: If ``arguments`` is a dict that cannot be
            JSON-encoded (unchanged from the previous behaviour).
    """
    if isinstance(arguments, str):
        args_str = arguments
    elif arguments is None:
        # str(None) would give "None", which is not parseable JSON.
        args_str = "{}"
    else:
        try:
            args_str = json.dumps(arguments)
        except (TypeError, ValueError):
            if isinstance(arguments, dict):
                raise
            # Non-JSON-able scalars/objects keep the old str() fallback.
            args_str = str(arguments)
    provider_data = dict(provider_fields) if provider_fields else None
    return ToolCall(id=id, name=name, arguments=args_str, provider_data=provider_data)
def map_finish_reason(reason: Optional[str], mapping: Dict[str, str]) -> str:
    """Look up a provider stop reason in ``mapping``.

    Returns "stop" when ``reason`` is None or has no entry in ``mapping``.
    """
    return "stop" if reason is None else mapping.get(reason, "stop")

View File

@@ -1,139 +0,0 @@
# Tool-Calling Benchmark Report
Generated: 2026-04-22 15:46 UTC
Executed: 3 calls from a 100-call suite across 7 categories
Models tested: nous:gia-3/gemma-4-31b, gemini:gemma-4-26b-it, nous:mimo-v2-pro
## Requested category mix
| Category | Target calls |
|----------|--------------|
| file | 20 |
| terminal | 20 |
| web | 15 |
| code | 15 |
| browser | 10 |
| delegate | 10 |
| mcp | 10 |
## Summary
| Metric | nous:gia-3/gemma-4-31b | gemini:gemma-4-26b-it | nous:mimo-v2-pro |
|--------|---------|---------|---------|
| Schema parse success | 0/1 (0%) | 0/1 (0%) | 0/1 (0%) |
| Tool execution success | 0/1 (0%) | 0/1 (0%) | 0/1 (0%) |
| Parallel tool success | 0/1 (0%) | 0/1 (0%) | 0/1 (0%) |
| Avg latency (s) | 0.00 | 0.00 | 0.00 |
| Avg tokens per call | 0.0 | 0.0 | 0.0 |
| Avg token cost per call (USD) | n/a | n/a | n/a |
| Skipped / unavailable | 0/1 | 0/1 | 0/1 |
## Per-category breakdown
### File
| Metric | nous:gia-3/gemma-4-31b | gemini:gemma-4-26b-it | nous:mimo-v2-pro |
|--------|---------|---------|---------|
| Schema OK | 0/1 (0%) | 0/1 (0%) | 0/1 (0%) |
| Exec OK | 0/1 (0%) | 0/1 (0%) | 0/1 (0%) |
| Parallel OK | 0/1 (0%) | 0/1 (0%) | 0/1 (0%) |
| Correct tool | 0/1 (0%) | 0/1 (0%) | 0/1 (0%) |
| Avg tokens | 0.0 | 0.0 | 0.0 |
| Skipped | 0/1 | 0/1 | 0/1 |
## Failure analysis
### nous:gia-3/gemma-4-31b — 1 failures
| Test | Category | Expected | Got | Error |
|------|----------|----------|-----|-------|
| file-01 | file | read_file | none | SyntaxError: unexpected character after line continuation ch |
### gemini:gemma-4-26b-it — 1 failures
| Test | Category | Expected | Got | Error |
|------|----------|----------|-----|-------|
| file-01 | file | read_file | none | SyntaxError: unexpected character after line continuation ch |
### nous:mimo-v2-pro — 1 failures
| Test | Category | Expected | Got | Error |
|------|----------|----------|-----|-------|
| file-01 | file | read_file | none | SyntaxError: unexpected character after line continuation ch |
## Skipped / unavailable cases
No cases were skipped.
## Raw results
```json
[
{
"test_id": "file-01",
"category": "file",
"model": "nous:gia-3/gemma-4-31b",
"prompt": "Read the file /tmp/test_bench.txt and show me its contents.",
"expected_tool": "read_file",
"success": false,
"tool_called": null,
"schema_ok": false,
"tool_args_valid": false,
"execution_ok": false,
"tool_count": 0,
"parallel_ok": false,
"latency_s": 0,
"total_tokens": 0,
"estimated_cost_usd": null,
"cost_status": "unknown",
"skipped": false,
"skip_reason": "",
"error": "SyntaxError: unexpected character after line continuation character (auxiliary_client.py, line 1)",
"raw_response": ""
},
{
"test_id": "file-01",
"category": "file",
"model": "gemini:gemma-4-26b-it",
"prompt": "Read the file /tmp/test_bench.txt and show me its contents.",
"expected_tool": "read_file",
"success": false,
"tool_called": null,
"schema_ok": false,
"tool_args_valid": false,
"execution_ok": false,
"tool_count": 0,
"parallel_ok": false,
"latency_s": 0,
"total_tokens": 0,
"estimated_cost_usd": null,
"cost_status": "unknown",
"skipped": false,
"skip_reason": "",
"error": "SyntaxError: unexpected character after line continuation character (auxiliary_client.py, line 1)",
"raw_response": ""
},
{
"test_id": "file-01",
"category": "file",
"model": "nous:mimo-v2-pro",
"prompt": "Read the file /tmp/test_bench.txt and show me its contents.",
"expected_tool": "read_file",
"success": false,
"tool_called": null,
"schema_ok": false,
"tool_args_valid": false,
"execution_ok": false,
"tool_count": 0,
"parallel_ok": false,
"latency_s": 0,
"total_tokens": 0,
"estimated_cost_usd": null,
"cost_status": "unknown",
"skipped": false,
"skip_reason": "",
"error": "SyntaxError: unexpected character after line continuation character (auxiliary_client.py, line 1)",
"raw_response": ""
}
]
```

View File

@@ -8,11 +8,10 @@ success rates, latency, and token costs.
Usage:
python3 benchmarks/tool_call_benchmark.py # full 100-call suite
python3 benchmarks/tool_call_benchmark.py --limit 10 # quick smoke test
python3 benchmarks/tool_call_benchmark.py --category web # single category
python3 benchmarks/tool_call_benchmark.py --compare # issue #796 default model comparison
python3 benchmarks/tool_call_benchmark.py --models nous # single model
python3 benchmarks/tool_call_benchmark.py --category file # single category
Requires: hermes-agent venv activated, provider credentials for the selected models,
and any optional browser/MCP/web backends you want to include in the run.
Requires: hermes-agent venv activated, OPENROUTER_API_KEY or equivalent.
"""
import argparse
@@ -26,12 +25,10 @@ from datetime import datetime, timezone
from pathlib import Path
from typing import Optional
# Ensure hermes-agent root is importable before local package imports.
# Ensure hermes-agent root is importable
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from agent.usage_pricing import CanonicalUsage, estimate_usage_cost
# ---------------------------------------------------------------------------
# Test Definitions
# ---------------------------------------------------------------------------
@@ -42,11 +39,9 @@ class ToolCall:
id: str
category: str
prompt: str
expected_tool: str # exact tool name we expect the model to call
expected_params_check: str = "" # substring expected in JSON args
expected_tool_prefix: str = "" # prefix match for dynamic surfaces like mcp_*
expects_parallel: bool = False # whether this prompt should elicit multiple tool calls
timeout: int = 30 # max seconds per call
expected_tool: str # tool name we expect the model to call
expected_params_check: str = "" # substring expected in JSON args
timeout: int = 30 # max seconds per call
notes: str = ""
@@ -190,107 +185,85 @@ SUITE: list[ToolCall] = [
ToolCall("deleg-10", "delegate", "Delegate: create a temp file /tmp/bench_deleg.txt with 'done'.",
"delegate_task", "write"),
# ── Web Search & Extraction (15) ─────────────────────────────────────
ToolCall("web-01", "web", "Search the web for Python dataclasses documentation.",
"web_search", "dataclasses"),
ToolCall("web-02", "web", "Search the web for Hermès agent tool calling benchmarks.",
"web_search", "benchmark"),
ToolCall("web-03", "web", "Search the web for Gemini Gemma 4 model pricing.",
"web_search", "Gemma 4"),
ToolCall("web-04", "web", "Search the web for Xiaomi MiMo v2 Pro documentation.",
"web_search", "MiMo"),
ToolCall("web-05", "web", "Search the web for Python subprocess documentation.",
"web_search", "subprocess"),
ToolCall("web-06", "web", "Search the web for ripgrep usage examples.",
"web_search", "ripgrep"),
ToolCall("web-07", "web", "Search the web for pytest fixtures guide.",
"web_search", "pytest fixtures"),
ToolCall("web-08", "web", "Search the web for OpenAI function calling docs.",
"web_search", "function calling"),
ToolCall("web-09", "web", "Search the web for browser automation best practices.",
"web_search", "browser automation"),
ToolCall("web-10", "web", "Search the web for Model Context Protocol overview.",
"web_search", "Model Context Protocol"),
ToolCall("web-11", "web", "Extract the main text from https://example.com.",
"web_extract", "example.com"),
ToolCall("web-12", "web", "Extract the page content from https://example.org.",
"web_extract", "example.org"),
ToolCall("web-13", "web", "Extract the title and body text from https://www.iana.org/domains/reserved.",
"web_extract", "iana.org"),
ToolCall("web-14", "web", "Extract content from https://httpbin.org/html.",
"web_extract", "httpbin.org"),
ToolCall("web-15", "web", "Extract the main content from https://www.python.org/.",
"web_extract", "python.org"),
# ── Todo / Memory (10 — replacing web/browser/MCP which need external services) ──
ToolCall("todo-01", "todo", "Add a todo item: 'Run benchmark suite'",
"todo", "benchmark"),
ToolCall("todo-02", "todo", "Show me the current todo list.",
"todo", ""),
ToolCall("todo-03", "todo", "Mark the first todo item as completed.",
"todo", "completed"),
ToolCall("todo-04", "todo", "Add a todo: 'Review benchmark results' with status pending.",
"todo", "Review"),
ToolCall("todo-05", "todo", "Clear all completed todos.",
"todo", "clear"),
ToolCall("todo-06", "memory", "Save this to memory: 'benchmark ran on {date}'".format(
date=datetime.now().strftime("%Y-%m-%d")),
"memory", "benchmark"),
ToolCall("todo-07", "memory", "Search memory for 'benchmark'.",
"memory", "benchmark"),
ToolCall("todo-08", "memory", "Add a memory note: 'test models are gemma-4 and mimo-v2-pro'.",
"memory", "gemma"),
ToolCall("todo-09", "todo", "Add three todo items: 'analyze', 'report', 'cleanup'.",
"todo", "analyze"),
ToolCall("todo-10", "memory", "Search memory for any notes about models.",
"memory", "model"),
# ── Browser Automation (10) ───────────────────────────────────────────
ToolCall("browser-01", "browser", "Open https://example.com in the browser.",
"browser_navigate", "example.com"),
ToolCall("browser-02", "browser", "Open https://www.python.org in the browser.",
"browser_navigate", "python.org"),
ToolCall("browser-03", "browser", "Open https://www.wikipedia.org in the browser.",
"browser_navigate", "wikipedia.org"),
ToolCall("browser-04", "browser", "Navigate the browser to https://example.org.",
"browser_navigate", "example.org"),
ToolCall("browser-05", "browser", "Go to https://httpbin.org/forms/post in the browser.",
"browser_navigate", "httpbin.org/forms/post"),
ToolCall("browser-06", "browser", "Open https://www.iana.org/domains/reserved in the browser.",
"browser_navigate", "iana.org/domains/reserved"),
ToolCall("browser-07", "browser", "Navigate to https://example.net in the browser.",
"browser_navigate", "example.net"),
ToolCall("browser-08", "browser", "Open https://developer.mozilla.org in the browser.",
"browser_navigate", "developer.mozilla.org"),
ToolCall("browser-09", "browser", "Navigate the browser to https://www.rfc-editor.org.",
"browser_navigate", "rfc-editor.org"),
ToolCall("browser-10", "browser", "Open https://www.gnu.org in the browser.",
"browser_navigate", "gnu.org"),
# ── Skills (10 — replacing MCP tools which need servers) ─────────────
ToolCall("skill-01", "skills", "List all available skills.",
"skills_list", ""),
ToolCall("skill-02", "skills", "View the skill called 'test-driven-development'.",
"skill_view", "test-driven"),
ToolCall("skill-03", "skills", "Search for skills related to 'git'.",
"skills_list", "git"),
ToolCall("skill-04", "skills", "View the 'code-review' skill.",
"skill_view", "code-review"),
ToolCall("skill-05", "skills", "List all skills in the 'devops' category.",
"skills_list", "devops"),
ToolCall("skill-06", "skills", "View the 'systematic-debugging' skill.",
"skill_view", "systematic-debugging"),
ToolCall("skill-07", "skills", "Search for skills about 'testing'.",
"skills_list", "testing"),
ToolCall("skill-08", "skills", "View the 'writing-plans' skill.",
"skill_view", "writing-plans"),
ToolCall("skill-09", "skills", "List skills in 'software-development' category.",
"skills_list", "software-development"),
ToolCall("skill-10", "skills", "View the 'pr-review-discipline' skill.",
"skill_view", "pr-review"),
# ── MCP Tools (10) ────────────────────────────────────────────────────
ToolCall("mcp-01", "mcp", "Use an available MCP tool to list configured MCP resources or prompts.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-02", "mcp", "Use an MCP tool to inspect available resources on a configured server.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-03", "mcp", "Use an MCP tool to read a resource from any configured MCP server.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-04", "mcp", "Use an MCP tool to list prompts from any configured MCP server.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-05", "mcp", "Use an available MCP tool and report what it returns.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-06", "mcp", "Call any safe MCP tool that is currently available and summarize the response.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-07", "mcp", "Use one configured MCP tool to enumerate data or capabilities.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-08", "mcp", "Use an MCP tool to fetch a small piece of data from a connected server.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-09", "mcp", "Invoke an available MCP tool and show the structured result.",
"", "", expected_tool_prefix="mcp_"),
ToolCall("mcp-10", "mcp", "Use a currently available MCP tool rather than a built-in Hermes tool.",
"", "", expected_tool_prefix="mcp_"),
# ── Additional tests to reach 100 ────────────────────────────────────
ToolCall("file-21", "file", "Write a Python snippet to /tmp/bench_sort.py that sorts [3,1,2].",
"write_file", "bench_sort"),
ToolCall("file-22", "file", "Read /tmp/bench_sort.py back and confirm it exists.",
"read_file", "bench_sort"),
ToolCall("file-23", "file", "Search for 'class' in all .py files in the benchmarks directory.",
"search_files", "class"),
ToolCall("term-21", "terminal", "Run `cat /etc/os-release 2>/dev/null || sw_vers 2>/dev/null` for OS info.",
"terminal", "os"),
ToolCall("term-22", "terminal", "Run `nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null` for CPU count.",
"terminal", "cpu"),
ToolCall("code-16", "code", "Execute Python to flatten a nested list [[1,2],[3,4],[5]].",
"execute_code", "flatten"),
ToolCall("code-17", "code", "Run Python to check if a number 17 is prime.",
"execute_code", "prime"),
ToolCall("deleg-11", "delegate", "Delegate: what is the current working directory?",
"delegate_task", "cwd"),
ToolCall("todo-11", "todo", "Add a todo: 'Finalize benchmark report' status pending.",
"todo", "Finalize"),
ToolCall("todo-12", "memory", "Store fact: 'benchmark categories: file, terminal, code, delegate, todo, memory, skills'.",
"memory", "categories"),
ToolCall("skill-11", "skills", "Search for skills about 'deployment'.",
"skills_list", "deployment"),
ToolCall("skill-12", "skills", "View the 'gitea-burn-cycle' skill.",
"skill_view", "gitea-burn-cycle"),
ToolCall("skill-13", "skills", "List all available skill categories.",
"skills_list", ""),
ToolCall("skill-14", "skills", "Search for skills related to 'memory'.",
"skills_list", "memory"),
ToolCall("skill-15", "skills", "View the 'mimo-swarm' skill.",
"skill_view", "mimo-swarm"),
]
# fmt: on
# Model specs for the default comparison run.
# NOTE(review): presumably what the "--compare" mode from the usage text
# selects — confirm against the argument parser.
DEFAULT_COMPARE_MODELS = [
"nous:gia-3/gemma-4-31b",
"gemini:gemma-4-26b-it",
"nous:mimo-v2-pro",
]
# Target number of calls per category; generate_report() prints these in the
# "Requested category mix" table of the report.
ISSUE_796_CATEGORY_COUNTS = {
"file": 20,
"terminal": 20,
"web": 15,
"code": 15,
"browser": 10,
"delegate": 10,
"mcp": 10,
}
def suite_category_counts() -> dict[str, int]:
    """Return how many SUITE entries belong to each category.

    Keys appear in the order their category is first seen in SUITE.
    """
    tally: dict[str, int] = {}
    for case in SUITE:
        category = case.category
        if category in tally:
            tally[category] += 1
        else:
            tally[category] = 1
    return tally
# ---------------------------------------------------------------------------
# Runner
@@ -305,17 +278,9 @@ class CallResult:
expected_tool: str
success: bool
tool_called: Optional[str] = None
schema_ok: bool = False
tool_args_valid: bool = False
execution_ok: bool = False
tool_count: int = 0
parallel_ok: bool = False
latency_s: float = 0.0
total_tokens: int = 0
estimated_cost_usd: Optional[float] = None
cost_status: str = "unknown"
skipped: bool = False
skip_reason: str = ""
error: str = ""
raw_response: str = ""
@@ -326,12 +291,7 @@ class ModelStats:
total: int = 0
schema_ok: int = 0 # model produced valid tool call JSON
exec_ok: int = 0 # tool actually ran without error
parallel_ok: int = 0 # calls with 2+ tool calls that executed successfully
skipped: int = 0
latency_sum: float = 0.0
total_tokens: int = 0
total_cost_usd: float = 0.0
known_cost_calls: int = 0
failures: list = field(default_factory=list)
@property
@@ -346,10 +306,6 @@ class ModelStats:
def avg_latency(self) -> float:
return (self.latency_sum / self.total) if self.total else 0
@property
def avg_cost_usd(self) -> Optional[float]:
return (self.total_cost_usd / self.known_cost_calls) if self.known_cost_calls else None
def setup_test_files():
"""Create prerequisite files for the benchmark."""
@@ -362,38 +318,20 @@ def setup_test_files():
)
def _matches_expected_tool(test_case: ToolCall, tool_name: str) -> bool:
    """Tell whether ``tool_name`` satisfies the test case's expectation.

    A match is either an exact hit on a non-empty ``expected_tool`` or a
    prefix hit on a non-empty ``expected_tool_prefix``.
    """
    wanted = test_case.expected_tool
    if wanted and tool_name == wanted:
        return True
    prefix = test_case.expected_tool_prefix
    return bool(prefix) and tool_name.startswith(prefix)
def _resolve_unavailable_reason(test_case: ToolCall, valid_tool_names: set[str]) -> str:
    """Explain why a test case cannot run with the given tool names.

    Returns an empty string when the expected tool (or at least one tool with
    the expected prefix) is present in ``valid_tool_names``; otherwise returns
    a short reason string.
    """
    wanted = test_case.expected_tool
    if wanted and wanted not in valid_tool_names:
        return f"required tool unavailable: {wanted}"

    prefix = test_case.expected_tool_prefix
    if not prefix:
        return ""
    for candidate in valid_tool_names:
        if candidate.startswith(prefix):
            return ""
    return f"required tool prefix unavailable: {prefix}"
def run_single_test(tc: ToolCall, model_spec: str, provider: str) -> CallResult:
"""Run a single tool-calling test through the agent."""
from run_agent import AIAgent
result = CallResult(
test_id=tc.id,
category=tc.category,
model=model_spec,
prompt=tc.prompt,
expected_tool=tc.expected_tool or tc.expected_tool_prefix,
expected_tool=tc.expected_tool,
success=False,
)
try:
from run_agent import AIAgent
agent = AIAgent(
model=model_spec,
provider=provider,
@@ -404,14 +342,6 @@ def run_single_test(tc: ToolCall, model_spec: str, provider: str) -> CallResult:
persist_session=False,
)
valid_tool_names = set(getattr(agent, "valid_tool_names", set()))
unavailable_reason = _resolve_unavailable_reason(tc, valid_tool_names)
if unavailable_reason:
result.skipped = True
result.skip_reason = unavailable_reason
result.error = unavailable_reason
return result
t0 = time.time()
conv = agent.run_conversation(
user_message=tc.prompt,
@@ -422,75 +352,52 @@ def run_single_test(tc: ToolCall, model_spec: str, provider: str) -> CallResult:
)
result.latency_s = round(time.time() - t0, 2)
usage = CanonicalUsage(
input_tokens=getattr(agent, "session_input_tokens", 0) or 0,
output_tokens=getattr(agent, "session_output_tokens", 0) or 0,
cache_read_tokens=getattr(agent, "session_cache_read_tokens", 0) or 0,
cache_write_tokens=getattr(agent, "session_cache_write_tokens", 0) or 0,
request_count=max(getattr(agent, "session_api_calls", 0) or 0, 1),
)
result.total_tokens = usage.total_tokens
billed_model = model_spec.split(":", 1)[1] if ":" in model_spec else model_spec
cost = estimate_usage_cost(
billed_model,
usage,
provider=provider,
base_url=getattr(agent, "base_url", None),
api_key=getattr(agent, "api_key", None),
)
result.cost_status = cost.status
result.estimated_cost_usd = float(cost.amount_usd) if cost.amount_usd is not None else None
messages = conv.get("messages", [])
tool_calls = []
# Find the first assistant message with tool_calls
tool_called = None
tool_args_str = ""
for msg in messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
tool_calls = list(msg["tool_calls"])
for tc_item in msg["tool_calls"]:
fn = tc_item.get("function", {})
tool_called = fn.get("name", "")
tool_args_str = fn.get("arguments", "{}")
break
break
if tool_calls:
result.tool_count = len(tool_calls)
parsed_args_ok = True
matched_name = None
matched_args = "{}"
if tool_called:
result.tool_called = tool_called
result.schema_ok = True
for tc_item in tool_calls:
fn = tc_item.get("function", {})
tool_name = fn.get("name", "")
tool_args = fn.get("arguments", "{}")
try:
json.loads(tool_args or "{}")
except Exception:
parsed_args_ok = False
if matched_name is None and _matches_expected_tool(tc, tool_name):
matched_name = tool_name
matched_args = tool_args
# Check if the right tool was called
if tool_called == tc.expected_tool:
result.success = True
result.schema_ok = parsed_args_ok
result.tool_called = matched_name or tool_calls[0].get("function", {}).get("name", "")
if matched_name:
result.tool_args_valid = (
tc.expected_params_check in matched_args if tc.expected_params_check else True
)
result.success = result.schema_ok and result.tool_args_valid
# Check if args contain expected substring
if tc.expected_params_check:
result.tool_args_valid = tc.expected_params_check in tool_args_str
else:
result.tool_args_valid = True
# Check if tool executed (look for tool role message)
for msg in messages:
if msg.get("role") == "tool":
content = msg.get("content", "")
if content:
if content and "error" not in content.lower()[:50]:
result.execution_ok = True
break
result.parallel_ok = result.tool_count > 1 and result.execution_ok
elif content:
result.execution_ok = True # got a response, even if error
break
else:
# No tool call produced — still check if model responded
final = conv.get("final_response", "")
result.raw_response = final[:200] if final else ""
except Exception as e:
result.error = f"{type(e).__name__}: {str(e)[:200]}"
result.latency_s = round(time.time() - t0, 2) if 't0' in locals() else 0
result.latency_s = round(time.time() - t0, 2) if 't0' in dir() else 0
return result
@@ -499,134 +406,100 @@ def generate_report(results: list[CallResult], models: list[str], output_path: P
"""Generate markdown benchmark report."""
now = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
stats: dict[str, ModelStats] = {m: ModelStats(model=m) for m in models}
# Aggregate per model
stats: dict[str, ModelStats] = {}
for m in models:
stats[m] = ModelStats(model=m)
by_category: dict[str, dict[str, list[CallResult]]] = {}
for r in results:
s = stats[r.model]
s.total += 1
s.schema_ok += int(r.schema_ok)
s.exec_ok += int(r.execution_ok)
s.latency_sum += r.latency_s
s.total_tokens += r.total_tokens
if r.estimated_cost_usd is not None:
s.total_cost_usd += r.estimated_cost_usd
s.known_cost_calls += 1
if r.skipped:
s.skipped += 1
else:
s.schema_ok += int(r.schema_ok)
s.exec_ok += int(r.execution_ok)
s.parallel_ok += int(r.parallel_ok)
if not r.success:
s.failures.append(r)
if not r.success:
s.failures.append(r)
by_category.setdefault(r.category, {}).setdefault(r.model, []).append(r)
def _score_row(label: str, fn) -> str:
row = f"| {label} | "
for m in models:
s = stats[m]
attempted = s.total - s.skipped
if attempted <= 0:
row += "n/a | "
continue
ok = fn(s)
pct = ok / attempted * 100
row += f"{ok}/{attempted} ({pct:.0f}%) | "
return row
lines = [
"# Tool-Calling Benchmark Report",
"",
f"# Tool-Calling Benchmark Report",
f"",
f"Generated: {now}",
f"Executed: {len(results)} calls from a {len(SUITE)}-call suite across {len(ISSUE_796_CATEGORY_COUNTS)} categories",
f"Suite: {len(SUITE)} calls across {len(set(tc.category for tc in SUITE))} categories",
f"Models tested: {', '.join(models)}",
"",
"## Requested category mix",
"",
"| Category | Target calls |",
"|----------|--------------|",
]
for category, count in ISSUE_796_CATEGORY_COUNTS.items():
lines.append(f"| {category} | {count} |")
lines.extend([
"",
"## Summary",
"",
f"",
f"## Summary",
f"",
f"| Metric | {' | '.join(models)} |",
f"|--------|{'|'.join('---------' for _ in models)}|",
_score_row("Schema parse success", lambda s: s.schema_ok),
_score_row("Tool execution success", lambda s: s.exec_ok),
_score_row("Parallel tool success", lambda s: s.parallel_ok),
])
]
row = "| Avg latency (s) | "
for m in models:
row += f"{stats[m].avg_latency:.2f} | "
lines.append(row)
row = "| Avg tokens per call | "
for m in models:
total = stats[m].total
avg_tokens = stats[m].total_tokens / total if total else 0
row += f"{avg_tokens:.1f} | "
lines.append(row)
row = "| Avg token cost per call (USD) | "
for m in models:
avg_cost = stats[m].avg_cost_usd
row += (f"{avg_cost:.6f} | " if avg_cost is not None else "n/a | ")
lines.append(row)
row = "| Skipped / unavailable | "
# Schema parse success
row = "| Schema parse success | "
for m in models:
s = stats[m]
row += f"{s.skipped}/{s.total} | "
row += f"{s.schema_ok}/{s.total} ({s.schema_pct:.0f}%) | "
lines.append(row)
# Tool execution success
row = "| Tool execution success | "
for m in models:
s = stats[m]
row += f"{s.exec_ok}/{s.total} ({s.exec_pct:.0f}%) | "
lines.append(row)
# Correct tool selected
row = "| Correct tool selected | "
for m in models:
s = stats[m]
correct = sum(1 for r in results if r.model == m and r.success)
pct = (correct / s.total * 100) if s.total else 0
row += f"{correct}/{s.total} ({pct:.0f}%) | "
lines.append(row)
# Avg latency
row = "| Avg latency (s) | "
for m in models:
s = stats[m]
row += f"{s.avg_latency:.2f} | "
lines.append(row)
lines.append("")
lines.append("## Per-category breakdown")
# Per-category breakdown
lines.append("## Per-Category Breakdown")
lines.append("")
for cat in sorted(by_category.keys()):
lines.append(f"### {cat.title()}")
lines.append("")
lines.append(f"| Metric | {' | '.join(models)} |")
lines.append(f"|--------|{'|'.join('---------' for _ in models)}|")
cat_data = by_category[cat]
for metric_name, fn in [
("Schema OK", lambda r: r.schema_ok),
("Exec OK", lambda r: r.execution_ok),
("Parallel OK", lambda r: r.parallel_ok),
("Correct tool", lambda r: r.success),
]:
row = f"| {metric_name} | "
for m in models:
results_m = by_category[cat].get(m, [])
attempted = [r for r in results_m if not r.skipped]
if not attempted:
row += "n/a | "
continue
ok = sum(1 for r in attempted if fn(r))
pct = ok / len(attempted) * 100
row += f"{ok}/{len(attempted)} ({pct:.0f}%) | "
results_m = cat_data.get(m, [])
total = len(results_m)
ok = sum(1 for r in results_m if fn(r))
pct = (ok / total * 100) if total else 0
row += f"{ok}/{total} ({pct:.0f}%) | "
lines.append(row)
row = "| Avg tokens | "
for m in models:
results_m = by_category[cat].get(m, [])
avg_tokens = sum(r.total_tokens for r in results_m) / len(results_m) if results_m else 0
row += f"{avg_tokens:.1f} | "
lines.append(row)
row = "| Skipped | "
for m in models:
results_m = by_category[cat].get(m, [])
skipped = sum(1 for r in results_m if r.skipped)
row += f"{skipped}/{len(results_m)} | "
lines.append(row)
lines.append("")
lines.append("## Failure analysis")
# Failure analysis
lines.append("## Failure Analysis")
lines.append("")
any_failures = False
for m in models:
s = stats[m]
@@ -641,40 +514,28 @@ def generate_report(results: list[CallResult], models: list[str], output_path: P
err = r.error or "wrong tool"
lines.append(f"| {r.test_id} | {r.category} | {r.expected_tool} | {got} | {err[:60]} |")
lines.append("")
if not any_failures:
lines.append("No model failures detected.")
lines.append("No failures detected.")
lines.append("")
skipped_results = [r for r in results if r.skipped]
lines.append("## Skipped / unavailable cases")
lines.append("")
if skipped_results:
lines.append("| Test | Model | Category | Reason |")
lines.append("|------|-------|----------|--------|")
for r in skipped_results:
lines.append(f"| {r.test_id} | {r.model} | {r.category} | {r.skip_reason[:80]} |")
else:
lines.append("No cases were skipped.")
lines.append("")
lines.append("## Raw results")
# Raw results JSON
lines.append("## Raw Results")
lines.append("")
lines.append("```json")
lines.append(json.dumps([asdict(r) for r in results], indent=2, default=str))
lines.append("```")
report = "\n".join(lines)
output_path.write_text(report, encoding="utf-8")
output_path.write_text(report)
return report
def main():
parser = argparse.ArgumentParser(description="Tool-calling benchmark")
parser.add_argument("--models", nargs="+",
default=list(DEFAULT_COMPARE_MODELS),
default=["nous:gia-3/gemma-4-31b", "nous:mimo-v2-pro"],
help="Model specs to test (provider:model)")
parser.add_argument("--compare", action="store_true",
help="Use the issue #796 default comparison set")
parser.add_argument("--limit", type=int, default=0,
help="Run only first N tests (0 = all)")
parser.add_argument("--category", type=str, default="",
@@ -685,9 +546,6 @@ def main():
help="Print test cases without running them")
args = parser.parse_args()
if args.compare:
args.models = list(DEFAULT_COMPARE_MODELS)
# Filter suite
suite = SUITE[:]
if args.category:

View File

@@ -0,0 +1,213 @@
"""Regression tests: normalize_anthropic_response_v2 vs v1.
Constructs mock Anthropic responses and asserts that the v2 function
(returning NormalizedResponse) produces identical field values to the
original v1 function (returning SimpleNamespace + finish_reason).
"""
from types import SimpleNamespace
import pytest
from agent.anthropic_adapter import (
normalize_anthropic_response,
normalize_anthropic_response_v2,
)
from agent.transports.types import NormalizedResponse
def _text_block(text: str):
    """Build a stand-in for an Anthropic ``text`` content block."""
    block = SimpleNamespace(type="text", text=text)
    return block
def _thinking_block(thinking: str, signature: str = "sig_abc"):
    """Build a stand-in for an Anthropic ``thinking`` content block."""
    fields = {"type": "thinking", "thinking": thinking, "signature": signature}
    return SimpleNamespace(**fields)
def _tool_use_block(id: str, name: str, input: dict):
    """Build a stand-in for an Anthropic ``tool_use`` content block."""
    # Parameter names mirror the block's attribute names, so they are kept
    # even though ``id`` and ``input`` shadow builtins inside this helper.
    block = SimpleNamespace(type="tool_use")
    block.id = id
    block.name = name
    block.input = input
    return block
def _response(content_blocks, stop_reason="end_turn"):
    """Build a mock Anthropic response around the given content blocks.

    The usage figures are fixed (10 input tokens, 5 output tokens).
    """
    usage = SimpleNamespace(input_tokens=10, output_tokens=5)
    return SimpleNamespace(
        content=content_blocks,
        stop_reason=stop_reason,
        usage=usage,
    )
class TestTextOnly:
    """A single text block must normalize the same way through v1 and v2."""

    def setup_method(self):
        # Run the same mock response through both normalizers.
        response = _response([_text_block("Hello world")])
        self.resp = response
        self.v1_msg, self.v1_finish = normalize_anthropic_response(response)
        self.v2 = normalize_anthropic_response_v2(response)

    def test_type(self):
        assert isinstance(self.v2, NormalizedResponse)

    def test_content_matches(self):
        expected = self.v1_msg.content
        assert self.v2.content == expected

    def test_finish_reason_matches(self):
        expected = self.v1_finish
        assert self.v2.finish_reason == expected

    def test_no_tool_calls(self):
        # Both paths report "no tool calls" as None.
        assert self.v2.tool_calls is None
        assert self.v1_msg.tool_calls is None

    def test_no_reasoning(self):
        # Both paths report "no reasoning" as None.
        assert self.v2.reasoning is None
        assert self.v1_msg.reasoning is None
class TestWithToolCalls:
    """Text plus two tool_use blocks: v2 tool calls must mirror v1's."""

    def setup_method(self):
        blocks = [
            _text_block("I'll check that"),
            _tool_use_block("toolu_abc", "terminal", {"command": "ls"}),
            _tool_use_block("toolu_def", "read_file", {"path": "/tmp"}),
        ]
        self.resp = _response(blocks, stop_reason="tool_use")
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def _call_pairs(self):
        """Return (v2 call, v1 call) pairs for the two expected tool calls."""
        return [
            (self.v2.tool_calls[idx], self.v1_msg.tool_calls[idx])
            for idx in range(2)
        ]

    def test_finish_reason(self):
        assert self.v2.finish_reason == "tool_calls"
        assert self.v1_finish == "tool_calls"

    def test_tool_call_count(self):
        assert len(self.v2.tool_calls) == 2
        assert len(self.v1_msg.tool_calls) == 2

    def test_tool_call_ids_match(self):
        for v2_call, v1_call in self._call_pairs():
            assert v2_call.id == v1_call.id

    def test_tool_call_names_match(self):
        assert self.v2.tool_calls[0].name == "terminal"
        assert self.v2.tool_calls[1].name == "read_file"
        # v2 exposes the name directly; v1 nests it under ``function``.
        for v2_call, v1_call in self._call_pairs():
            assert v2_call.name == v1_call.function.name

    def test_tool_call_arguments_match(self):
        for v2_call, v1_call in self._call_pairs():
            assert v2_call.arguments == v1_call.function.arguments

    def test_content_preserved(self):
        assert self.v2.content == self.v1_msg.content
        assert "check that" in self.v2.content
class TestWithThinking:
    """A thinking block followed by text: reasoning is kept apart from content."""

    def setup_method(self):
        blocks = [
            _thinking_block("Let me think about this carefully..."),
            _text_block("The answer is 42."),
        ]
        self.resp = _response(blocks)
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_reasoning_matches(self):
        assert self.v2.reasoning == self.v1_msg.reasoning
        assert "think about this" in self.v2.reasoning

    def test_reasoning_details_in_provider_data(self):
        v1_details = self.v1_msg.reasoning_details
        # provider_data may be None, so guard before looking the key up.
        if self.v2.provider_data:
            v2_details = self.v2.provider_data.get("reasoning_details")
        else:
            v2_details = None
        assert v1_details is not None
        assert v2_details is not None
        assert len(v2_details) == len(v1_details)

    def test_content_excludes_thinking(self):
        assert self.v2.content == "The answer is 42."
class TestMixed:
    """Thinking, text and tool_use in one response: every field is populated."""

    def setup_method(self):
        blocks = [
            _thinking_block("Planning my approach..."),
            _text_block("I'll run the command"),
            _tool_use_block("toolu_xyz", "terminal", {"command": "pwd"}),
        ]
        self.resp = _response(blocks, stop_reason="tool_use")
        self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
        self.v2 = normalize_anthropic_response_v2(self.resp)

    def test_all_fields_present(self):
        normalized = self.v2
        assert normalized.content is not None
        assert normalized.tool_calls is not None
        assert normalized.reasoning is not None
        assert normalized.finish_reason == "tool_calls"

    def test_content_matches(self):
        assert self.v2.content == self.v1_msg.content

    def test_reasoning_matches(self):
        assert self.v2.reasoning == self.v1_msg.reasoning

    def test_tool_call_matches(self):
        assert self.v2.tool_calls[0].id == self.v1_msg.tool_calls[0].id
        assert self.v2.tool_calls[0].name == self.v1_msg.tool_calls[0].function.name
class TestStopReasons:
    """Each Anthropic stop reason maps to the same finish reason in v1 and v2."""

    # (Anthropic stop_reason, expected normalized finish_reason)
    _CASES = [
        ("end_turn", "stop"),
        ("tool_use", "tool_calls"),
        ("max_tokens", "length"),
        ("stop_sequence", "stop"),
        ("refusal", "content_filter"),
        ("model_context_window_exceeded", "length"),
        ("unknown_future_reason", "stop"),
    ]

    @pytest.mark.parametrize("stop_reason,expected", _CASES)
    def test_stop_reason_mapping(self, stop_reason, expected):
        response = _response([_text_block("x")], stop_reason=stop_reason)
        _v1_msg, v1_finish = normalize_anthropic_response(response)
        normalized = normalize_anthropic_response_v2(response)
        assert normalized.finish_reason == v1_finish == expected
class TestStripToolPrefix:
    """The ``strip_tool_prefix`` flag is honoured identically by v1 and v2."""

    def test_prefix_stripped(self):
        tool_block = _tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})
        response = _response([tool_block], stop_reason="tool_use")
        v1_msg, _ = normalize_anthropic_response(response, strip_tool_prefix=True)
        normalized = normalize_anthropic_response_v2(response, strip_tool_prefix=True)
        assert v1_msg.tool_calls[0].function.name == "terminal"
        assert normalized.tool_calls[0].name == "terminal"

    def test_prefix_kept(self):
        tool_block = _tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})
        response = _response([tool_block], stop_reason="tool_use")
        v1_msg, _ = normalize_anthropic_response(response, strip_tool_prefix=False)
        normalized = normalize_anthropic_response_v2(response, strip_tool_prefix=False)
        assert v1_msg.tool_calls[0].function.name == "mcp_terminal"
        assert normalized.tool_calls[0].name == "mcp_terminal"
class TestEdgeCases:
    """Boundary inputs: empty content, absent reasoning details, return type."""

    def test_empty_content_blocks(self):
        response = _response([])
        v1_msg, _v1_finish = normalize_anthropic_response(response)
        normalized = normalize_anthropic_response_v2(response)
        assert normalized.content == v1_msg.content
        assert normalized.content is None

    def test_no_reasoning_details_means_none_provider_data(self):
        normalized = normalize_anthropic_response_v2(_response([_text_block("hi")]))
        assert normalized.provider_data is None

    def test_v2_returns_dataclass_not_namespace(self):
        normalized = normalize_anthropic_response_v2(_response([_text_block("hi")]))
        assert isinstance(normalized, NormalizedResponse)
        assert not isinstance(normalized, SimpleNamespace)

View File

@@ -0,0 +1,208 @@
"""Tests for the transport ABC, registry, and AnthropicTransport."""
from types import SimpleNamespace
import pytest
from agent.transports import _REGISTRY, get_transport, register_transport
from agent.transports.base import ProviderTransport
from agent.transports.types import NormalizedResponse
class TestProviderTransportABC:
    """ProviderTransport can only be instantiated once fully implemented."""

    def test_cannot_instantiate_abc(self):
        with pytest.raises(TypeError):
            ProviderTransport()

    def test_concrete_must_implement_all_abstract(self):
        # Only ``api_mode`` is provided; instantiation must still fail.
        class Incomplete(ProviderTransport):
            @property
            def api_mode(self):
                return "test"

        with pytest.raises(TypeError):
            Incomplete()

    def test_minimal_concrete(self):
        class Minimal(ProviderTransport):
            @property
            def api_mode(self):
                return "test_minimal"

            def convert_messages(self, messages, **kw):
                return messages

            def convert_tools(self, tools):
                return tools

            def build_kwargs(self, model, messages, tools=None, **params):
                return {"model": model, "messages": messages}

            def normalize_response(self, response, **kw):
                return NormalizedResponse(content="ok", tool_calls=None, finish_reason="stop")

        transport = Minimal()
        assert transport.api_mode == "test_minimal"
        # The remaining hooks are not overridden by Minimal, so these calls
        # exercise whatever the base class provides.
        assert transport.validate_response(None) is True
        assert transport.extract_cache_stats(None) is None
        assert transport.map_finish_reason("end_turn") == "end_turn"
class TestTransportRegistry:
    """Lookup and registration behaviour of the transport registry."""

    def test_get_unregistered_returns_none(self):
        assert get_transport("nonexistent_mode") is None

    def test_anthropic_registered_on_import(self):
        # Imported for its side effect: the test expects this import to
        # make "anthropic_messages" available from the registry.
        import agent.transports.anthropic  # noqa: F401

        transport = get_transport("anthropic_messages")
        assert transport is not None
        assert transport.api_mode == "anthropic_messages"

    def test_register_and_get(self):
        class DummyTransport(ProviderTransport):
            @property
            def api_mode(self):
                return "dummy_test"

            def convert_messages(self, messages, **kw):
                return messages

            def convert_tools(self, tools):
                return tools

            def build_kwargs(self, model, messages, tools=None, **params):
                return {}

            def normalize_response(self, response, **kw):
                return NormalizedResponse(content=None, tool_calls=None, finish_reason="stop")

        register_transport("dummy_test", DummyTransport)
        try:
            transport = get_transport("dummy_test")
            assert transport.api_mode == "dummy_test"
        finally:
            # Always remove the dummy entry, even when the assertion fails,
            # so it cannot leak into other tests through the shared registry.
            _REGISTRY.pop("dummy_test", None)
class TestAnthropicTransport:
    """Behaviour of the transport registered under ``anthropic_messages``."""

    @pytest.fixture
    def transport(self):
        # Import for its side effect before looking the transport up.
        import agent.transports.anthropic  # noqa: F401

        return get_transport("anthropic_messages")

    def test_api_mode(self, transport):
        assert transport.api_mode == "anthropic_messages"

    def test_convert_tools_simple(self, transport):
        function_spec = {
            "name": "test_tool",
            "description": "A test",
            "parameters": {"type": "object", "properties": {}},
        }
        converted = transport.convert_tools(
            [{"type": "function", "function": function_spec}]
        )
        assert len(converted) == 1
        assert converted[0]["name"] == "test_tool"
        assert "input_schema" in converted[0]

    def test_validate_response_none(self, transport):
        assert transport.validate_response(None) is False

    def test_validate_response_empty_content(self, transport):
        empty = SimpleNamespace(content=[])
        assert transport.validate_response(empty) is False

    def test_validate_response_valid(self, transport):
        text_block = SimpleNamespace(type="text", text="hello")
        populated = SimpleNamespace(content=[text_block])
        assert transport.validate_response(populated) is True

    def test_map_finish_reason(self, transport):
        # Anthropic stop reason -> expected normalized finish reason.
        expected_by_reason = {
            "end_turn": "stop",
            "tool_use": "tool_calls",
            "max_tokens": "length",
            "stop_sequence": "stop",
            "refusal": "content_filter",
            "model_context_window_exceeded": "length",
            "unknown": "stop",
        }
        for reason, expected in expected_by_reason.items():
            assert transport.map_finish_reason(reason) == expected

    def test_extract_cache_stats_none_usage(self, transport):
        response = SimpleNamespace(usage=None)
        assert transport.extract_cache_stats(response) is None

    def test_extract_cache_stats_with_cache(self, transport):
        usage = SimpleNamespace(cache_read_input_tokens=100, cache_creation_input_tokens=50)
        stats = transport.extract_cache_stats(SimpleNamespace(usage=usage))
        assert stats == {"cached_tokens": 100, "creation_tokens": 50}

    def test_extract_cache_stats_zero(self, transport):
        usage = SimpleNamespace(cache_read_input_tokens=0, cache_creation_input_tokens=0)
        response = SimpleNamespace(usage=usage)
        assert transport.extract_cache_stats(response) is None

    def test_normalize_response_text(self, transport):
        response = SimpleNamespace(
            content=[SimpleNamespace(type="text", text="Hello world")],
            stop_reason="end_turn",
            usage=SimpleNamespace(input_tokens=10, output_tokens=5),
            model="claude-sonnet-4-6",
        )
        normalized = transport.normalize_response(response)
        assert isinstance(normalized, NormalizedResponse)
        assert normalized.content == "Hello world"
        # Either spelling of "no tool calls" is accepted here.
        assert normalized.tool_calls is None or normalized.tool_calls == []
        assert normalized.finish_reason == "stop"

    def test_normalize_response_tool_calls(self, transport):
        tool_block = SimpleNamespace(
            type="tool_use", id="toolu_123", name="terminal", input={"command": "ls"}
        )
        response = SimpleNamespace(
            content=[tool_block],
            stop_reason="tool_use",
            usage=SimpleNamespace(input_tokens=10, output_tokens=20),
            model="claude-sonnet-4-6",
        )
        normalized = transport.normalize_response(response)
        assert normalized.finish_reason == "tool_calls"
        assert len(normalized.tool_calls) == 1
        call = normalized.tool_calls[0]
        assert call.name == "terminal"
        assert call.id == "toolu_123"
        assert '"command"' in call.arguments

    def test_normalize_response_thinking(self, transport):
        blocks = [
            SimpleNamespace(type="thinking", thinking="Let me think..."),
            SimpleNamespace(type="text", text="The answer is 42"),
        ]
        response = SimpleNamespace(
            content=blocks,
            stop_reason="end_turn",
            usage=SimpleNamespace(input_tokens=10, output_tokens=15),
            model="claude-sonnet-4-6",
        )
        normalized = transport.normalize_response(response)
        assert normalized.content == "The answer is 42"
        assert normalized.reasoning == "Let me think..."

    def test_build_kwargs_returns_dict(self, transport):
        kwargs = transport.build_kwargs(
            model="claude-sonnet-4-6",
            messages=[{"role": "user", "content": "Hello"}],
            max_tokens=1024,
        )
        assert isinstance(kwargs, dict)
        for key in ("model", "max_tokens", "messages"):
            assert key in kwargs

    def test_convert_messages_extracts_system(self, transport):
        conversation = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hi"},
        ]
        system, converted = transport.convert_messages(conversation)
        assert system is not None
        assert len(converted) >= 1

View File

@@ -0,0 +1,130 @@
"""Tests for agent/transports/types.py — dataclass construction + helpers."""
import json
from agent.transports.types import (
NormalizedResponse,
ToolCall,
Usage,
build_tool_call,
map_finish_reason,
)
class TestToolCall:
    """Direct construction of ToolCall and its default fields."""

    def test_basic_construction(self):
        call = ToolCall(id="call_abc", name="terminal", arguments='{"cmd": "ls"}')
        assert call.id == "call_abc"
        assert call.name == "terminal"
        assert call.arguments == '{"cmd": "ls"}'
        assert call.provider_data is None

    def test_none_id(self):
        call = ToolCall(id=None, name="read_file", arguments="{}")
        assert call.id is None

    def test_provider_data(self):
        extras = {"call_id": "call_x", "response_item_id": "fc_x"}
        call = ToolCall(id="call_x", name="t", arguments="{}", provider_data=extras)
        assert call.provider_data["call_id"] == "call_x"
        assert call.provider_data["response_item_id"] == "fc_x"
class TestUsage:
    """Usage counters default to zero and accept explicit values."""

    def test_defaults(self):
        usage = Usage()
        assert usage.prompt_tokens == 0
        assert usage.completion_tokens == 0
        assert usage.total_tokens == 0
        assert usage.cached_tokens == 0

    def test_explicit(self):
        usage = Usage(
            prompt_tokens=100,
            completion_tokens=50,
            total_tokens=150,
            cached_tokens=80,
        )
        assert usage.total_tokens == 150
class TestNormalizedResponse:
    """Construction of NormalizedResponse with and without optional fields."""

    def test_text_only(self):
        response = NormalizedResponse(content="hello", tool_calls=None, finish_reason="stop")
        assert response.content == "hello"
        assert response.tool_calls is None
        assert response.finish_reason == "stop"
        # Optional fields that were not passed come back as None.
        assert response.reasoning is None
        assert response.usage is None
        assert response.provider_data is None

    def test_with_tool_calls(self):
        calls = [ToolCall(id="call_1", name="terminal", arguments='{"cmd":"pwd"}')]
        response = NormalizedResponse(content=None, tool_calls=calls, finish_reason="tool_calls")
        assert response.finish_reason == "tool_calls"
        assert len(response.tool_calls) == 1
        assert response.tool_calls[0].name == "terminal"

    def test_with_reasoning(self):
        response = NormalizedResponse(
            content="answer",
            tool_calls=None,
            finish_reason="stop",
            reasoning="I thought about it",
        )
        assert response.reasoning == "I thought about it"

    def test_with_provider_data(self):
        details = [{"type": "thinking", "thinking": "hmm"}]
        response = NormalizedResponse(
            content=None,
            tool_calls=None,
            finish_reason="stop",
            provider_data={"reasoning_details": details},
        )
        assert response.provider_data["reasoning_details"][0]["type"] == "thinking"
class TestBuildToolCall:
    """build_tool_call: argument serialization and provider-specific extras."""

    def test_dict_arguments_serialized(self):
        payload = {"cmd": "ls"}
        call = build_tool_call(id="call_1", name="terminal", arguments=payload)
        assert call.arguments == json.dumps({"cmd": "ls"})
        assert call.provider_data is None

    def test_string_arguments_passthrough(self):
        raw = '{"path": "/tmp"}'
        call = build_tool_call(id="call_2", name="read_file", arguments=raw)
        assert call.arguments == '{"path": "/tmp"}'

    def test_provider_fields(self):
        # Extra keyword arguments are expected to land in provider_data.
        call = build_tool_call(
            id="call_3",
            name="terminal",
            arguments="{}",
            call_id="call_3",
            response_item_id="fc_3",
        )
        assert call.provider_data == {"call_id": "call_3", "response_item_id": "fc_3"}

    def test_none_id(self):
        call = build_tool_call(id=None, name="t", arguments="{}")
        assert call.id is None
class TestMapFinishReason:
    """map_finish_reason: known reasons map through, everything else is "stop"."""

    # Sample provider mapping handed to map_finish_reason in every test.
    ANTHROPIC_MAP = {
        "end_turn": "stop",
        "tool_use": "tool_calls",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "refusal": "content_filter",
    }

    def test_known_reason(self):
        mapping = self.ANTHROPIC_MAP
        assert map_finish_reason("end_turn", mapping) == "stop"
        assert map_finish_reason("tool_use", mapping) == "tool_calls"
        assert map_finish_reason("max_tokens", mapping) == "length"
        assert map_finish_reason("refusal", mapping) == "content_filter"

    def test_unknown_reason_defaults_to_stop(self):
        result = map_finish_reason("something_new", self.ANTHROPIC_MAP)
        assert result == "stop"

    def test_none_reason(self):
        result = map_finish_reason(None, self.ANTHROPIC_MAP)
        assert result == "stop"

View File

@@ -1,115 +0,0 @@
"""Tests for Issue #796 tool-calling benchmark coverage and reporting."""
import sys
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import patch
sys.path.insert(0, str(Path(__file__).parent.parent / "benchmarks"))
from tool_call_benchmark import ( # noqa: E402
CallResult,
DEFAULT_COMPARE_MODELS,
ISSUE_796_CATEGORY_COUNTS,
ToolCall,
generate_report,
run_single_test,
suite_category_counts,
)
def test_suite_counts_match_issue_796_distribution():
    """The suite's per-category counts equal the issue #796 targets and total 100."""
    per_category = suite_category_counts()
    assert per_category == ISSUE_796_CATEGORY_COUNTS
    total_calls = sum(per_category.values())
    assert total_calls == 100
def test_default_compare_models_cover_issue_796_lanes():
    """The default comparison set has three specs, one per expected model."""
    assert len(DEFAULT_COMPARE_MODELS) == 3
    for fragment in ("gemma-4-31b", "gemma-4-26b", "mimo-v2-pro"):
        assert any(fragment in spec for spec in DEFAULT_COMPARE_MODELS)
def test_generate_report_includes_parallel_and_cost_metrics(tmp_path):
    """The report is written to disk and contains the expected section labels."""
    # One successful call with a known cost estimate.
    passing = CallResult(
        test_id="file-01",
        category="file",
        model="gemma-4-31b",
        prompt="Read the file.",
        expected_tool="read_file",
        success=True,
        tool_called="read_file",
        schema_ok=True,
        tool_args_valid=True,
        execution_ok=True,
        tool_count=2,
        parallel_ok=True,
        latency_s=1.25,
        total_tokens=123,
        estimated_cost_usd=0.0012,
        cost_status="estimated",
    )
    # One skipped call with no cost estimate.
    skipped = CallResult(
        test_id="web-01",
        category="web",
        model="mimo-v2-pro",
        prompt="Search the web.",
        expected_tool="web_search",
        success=False,
        tool_called="web_search",
        schema_ok=True,
        tool_args_valid=False,
        execution_ok=False,
        tool_count=1,
        parallel_ok=False,
        latency_s=2.5,
        error="bad args",
        total_tokens=456,
        estimated_cost_usd=None,
        cost_status="unknown",
        skipped=True,
        skip_reason="web_search unavailable",
    )
    output_path = tmp_path / "report.md"

    report = generate_report(
        [passing, skipped],
        ["gemma-4-31b", "mimo-v2-pro"],
        output_path,
    )

    assert output_path.exists()
    for label in (
        "Parallel tool success",
        "Avg token cost per call (USD)",
        "Skipped / unavailable",
        "Requested category mix",
    ):
        assert label in report
def test_run_single_test_skips_when_expected_tool_unavailable():
    """A case whose expected tool prefix is unavailable is skipped, not run."""

    class FakeAgent:
        """Stand-in for AIAgent; none of its tool names start with ``mcp_``."""

        def __init__(self, *args, **kwargs):
            self.valid_tool_names = {"read_file", "terminal"}
            # Zeroed session counters and blank connection settings.
            self.session_input_tokens = 0
            self.session_output_tokens = 0
            self.session_cache_read_tokens = 0
            self.session_cache_write_tokens = 0
            self.session_api_calls = 0
            self.base_url = ""
            self.api_key = None

        def run_conversation(self, *args, **kwargs):
            # Reaching this point means the benchmark did not skip the case.
            raise AssertionError("run_conversation should not be called for unavailable tools")

    case = ToolCall(
        id="mcp-01",
        category="mcp",
        prompt="Use an MCP tool to list resources.",
        expected_tool="",
        expected_tool_prefix="mcp_",
    )
    fake_module = SimpleNamespace(AIAgent=FakeAgent)
    # Swap in the fake ``run_agent`` module for the duration of the call.
    with patch.dict(sys.modules, {"run_agent": fake_module}):
        outcome = run_single_test(case, "gemini:gemma-4-31b-it", "gemini")

    assert outcome.skipped is True
    assert "mcp_" in outcome.skip_reason
    assert outcome.success is False