Compare commits

..

2 Commits

Author SHA1 Message Date
9288ae8be9 test(#745): Add tests for cost estimator
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 42s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 47s
Tests / e2e (pull_request) Successful in 4m31s
Tests / test (pull_request) Failing after 53m47s
Tests for cost estimation, pricing lookup.
Refs #745
2026-04-15 03:13:15 +00:00
f86233cd52 feat(#745): Add provider cost estimator tool
Cost estimation tool with:
- estimate_cost(input_tokens, output_tokens, provider, model)
- Pricing table for OpenRouter, Nous, Anthropic, local (free)
- Session cost estimation
- Cost report formatting

Resolves #745
2026-04-15 03:12:33 +00:00
4 changed files with 233 additions and 319 deletions

View File

@@ -1,245 +0,0 @@
"""Tool call fallback — retry failed tool calls with local model.
When the primary provider fails during tool calling (timeout, 429, 503),
fall back to local Ollama to keep the agent working.
Usage:
from agent.tool_fallback import ToolFallbackHandler
handler = ToolFallbackHandler()
result = handler.execute_with_fallback(tool_fn, args, context)
"""
from __future__ import annotations
import json
import logging
import os
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional
# Module-level logger for fallback diagnostics.
logger = logging.getLogger(__name__)
# Fallback provider config
# Defaults target a local Ollama instance speaking the OpenAI-compatible
# API; each value can be overridden via the corresponding env var.
_FALLBACK_PROVIDER = os.getenv("TOOL_FALLBACK_PROVIDER", "ollama")
_FALLBACK_MODEL = os.getenv("TOOL_FALLBACK_MODEL", "qwen2.5:7b")
_FALLBACK_BASE_URL = os.getenv("TOOL_FALLBACK_BASE_URL", "http://localhost:11434/v1")
# Error patterns that trigger fallback
# Matched case-insensitively as substrings of str(error) by
# ToolFallbackHandler.should_fallback; transient/capacity errors only —
# auth or "model not found" errors deliberately do NOT trigger fallback.
_FALLBACK_TRIGGERS = [
    "429",
    "rate limit",
    "ratelimit",
    "503",
    "service unavailable",
    "timeout",
    "timed out",
    "connection error",
    "connection refused",
    "overloaded",
    "capacity",
    "quota exceeded",
    "insufficient",
]
@dataclass
class FallbackEvent:
    """Record of a fallback event."""
    timestamp: float          # wall-clock time of the attempt (time.time())
    tool_name: str            # tool whose call was retried
    original_provider: str    # provider that failed before fallback
    fallback_provider: str    # provider used for the retry
    error: str                # error message on failure, "" on success
    success: bool             # True if the fallback call succeeded
    duration_ms: int = 0      # fallback call duration in milliseconds
class ToolFallbackHandler:
    """Handles tool call fallback to local models.

    When the primary provider fails with a transient error (rate limit,
    timeout, outage), a secondary OpenAI-compatible endpoint (local
    Ollama by default) is used instead so the agent keeps working.
    Tracks fallback events and provides fallback execution.
    """
    def __init__(
        self,
        fallback_provider: str = "",
        fallback_model: str = "",
        fallback_base_url: str = "",
        enabled: bool = True,
    ):
        """Create a handler.

        Empty-string arguments fall back to the module-level defaults
        derived from TOOL_FALLBACK_* environment variables.
        """
        self.fallback_provider = fallback_provider or _FALLBACK_PROVIDER
        self.fallback_model = fallback_model or _FALLBACK_MODEL
        self.fallback_base_url = fallback_base_url or _FALLBACK_BASE_URL
        self.enabled = enabled
        self._events: list[FallbackEvent] = []
        self._fallback_count = 0
        self._fallback_success_count = 0
    @property
    def events(self) -> list[FallbackEvent]:
        """Copy of all recorded fallback events, oldest first."""
        return list(self._events)
    @property
    def stats(self) -> dict:
        """Aggregate counters.

        'fallback_rate' is the success rate among attempted fallbacks
        (0 when no fallback has been attempted yet).
        """
        return {
            "total_fallbacks": self._fallback_count,
            "successful_fallbacks": self._fallback_success_count,
            "fallback_rate": (
                self._fallback_success_count / self._fallback_count
                if self._fallback_count > 0 else 0
            ),
        }
    def should_fallback(self, error: Any) -> bool:
        """Check if an error should trigger fallback.

        Returns False when the handler is disabled; otherwise matches
        str(error), lowercased, against the _FALLBACK_TRIGGERS substrings.
        """
        if not self.enabled:
            return False
        error_str = str(error).lower()
        return any(trigger in error_str for trigger in _FALLBACK_TRIGGERS)
    def get_fallback_client(self) -> Optional[Any]:
        """Get an OpenAI client configured for the fallback provider.

        Returns None (and logs) if the client cannot be constructed,
        e.g. when the `openai` package is not installed. Ollama ignores
        the API key, hence the "ollama" placeholder default.
        """
        try:
            from openai import OpenAI
            client = OpenAI(
                base_url=self.fallback_base_url,
                api_key=os.getenv("OPENAI_API_KEY", "ollama"),
            )
            return client
        except Exception as e:
            logger.error("Failed to create fallback client: %s", e)
            return None
    def call_with_fallback(
        self,
        messages: list[dict],
        tools: Optional[list[dict]] = None,
        original_provider: str = "",
        tool_name: str = "unknown",
        max_tokens: int = 1024,
    ) -> dict:
        """Make an API call with fallback to local model on failure.
        Args:
            messages: Conversation messages
            tools: Tool definitions (None to call without tools)
            original_provider: Name of the original provider
            tool_name: Name of the tool being called
            max_tokens: Max tokens for the response
        Returns:
            Dict with 'response', 'used_fallback', 'fallback_event' keys;
            an 'error' key is present when the attempt failed.
        """
        t0 = time.monotonic()
        # Try fallback client
        client = self.get_fallback_client()
        if not client:
            return {
                "response": None,
                "used_fallback": False,
                "error": "Fallback client unavailable",
            }
        try:
            response = client.chat.completions.create(
                model=self.fallback_model,
                messages=messages,
                tools=tools if tools else None,
                max_tokens=max_tokens,
            )
            elapsed = int((time.monotonic() - t0) * 1000)
            event = FallbackEvent(
                timestamp=time.time(),
                tool_name=tool_name,
                original_provider=original_provider,
                fallback_provider=self.fallback_provider,
                error="",
                success=True,
                duration_ms=elapsed,
            )
            self._events.append(event)
            self._fallback_count += 1
            self._fallback_success_count += 1
            logger.info(
                "Tool fallback succeeded: %s via %s (%dms)",
                tool_name, self.fallback_provider, elapsed,
            )
            return {
                "response": response,
                "used_fallback": True,
                "fallback_event": event,
            }
        except Exception as e:
            elapsed = int((time.monotonic() - t0) * 1000)
            event = FallbackEvent(
                timestamp=time.time(),
                tool_name=tool_name,
                original_provider=original_provider,
                fallback_provider=self.fallback_provider,
                error=str(e),
                success=False,
                duration_ms=elapsed,
            )
            self._events.append(event)
            self._fallback_count += 1
            # Fixed format string: the original "%s via %s%s" glued the
            # error message directly onto the provider name.
            logger.error(
                "Tool fallback failed: %s via %s: %s",
                tool_name, self.fallback_provider, e,
            )
            return {
                "response": None,
                "used_fallback": True,
                "fallback_event": event,
                "error": str(e),
            }
    def format_report(self) -> str:
        """Format fallback events as a report.

        Shows aggregate counts plus the 10 most recent events; error
        text is truncated to 100 characters.
        """
        if not self._events:
            return "No fallback events recorded."
        lines = [
            "Tool Fallback Report",
            "=" * 40,
            f"Total fallbacks: {self._fallback_count}",
            f"Successful: {self._fallback_success_count}",
            f"Failed: {self._fallback_count - self._fallback_success_count}",
            "",
        ]
        for event in self._events[-10:]:
            status = "OK" if event.success else "FAIL"
            lines.append(
                f"  [{status}] {event.tool_name} via {event.fallback_provider} "
                f"({event.duration_ms}ms) — {event.original_provider}"
            )
            if event.error:
                lines.append(f"      Error: {event.error[:100]}")
        return "\n".join(lines)
# Singleton handler
_handler: Optional[ToolFallbackHandler] = None
def get_tool_fallback_handler() -> ToolFallbackHandler:
    """Return the process-wide fallback handler, creating it on first use."""
    global _handler
    _handler = _handler or ToolFallbackHandler()
    return _handler
def reset_tool_fallback_handler() -> None:
    """Drop the singleton so the next get() builds a fresh handler (testing)."""
    global _handler
    _handler = None

View File

@@ -0,0 +1,41 @@
"""
Tests for cost estimator tool (#745).
"""
import pytest
from tools.cost_estimator import estimate_cost, get_pricing, CostEstimate, PRICING
class TestCostEstimator:
    """Unit tests for the cost estimator tool (#745)."""

    def test_estimate_cost_basic(self):
        """Token counts round-trip and a paid model yields a positive cost."""
        est = estimate_cost(1000, 500, "openrouter", "claude-sonnet-4")
        assert est.input_tokens == 1000
        assert est.output_tokens == 500
        assert est.total_cost_usd > 0

    def test_local_is_free(self):
        """Local provider costs nothing regardless of volume."""
        est = estimate_cost(1000000, 1000000, "local", "llama-3")
        assert est.total_cost_usd == 0.0

    def test_get_pricing_openrouter(self):
        """Known model resolves to its exact per-1M-token rates."""
        rates = get_pricing("openrouter", "claude-opus-4")
        assert (rates["input"], rates["output"]) == (15.0, 75.0)

    def test_get_pricing_unknown_model(self):
        """Unknown model falls back to the provider default entry."""
        assert get_pricing("openrouter", "unknown-model") == PRICING["openrouter"]["default"]

    def test_get_pricing_unknown_provider(self):
        """Unknown provider falls back to the OpenRouter default entry."""
        assert get_pricing("unknown-provider", "model") == PRICING["openrouter"]["default"]

    def test_cost_estimate_dataclass(self):
        """estimate_cost returns a CostEstimate carrying provider/model."""
        est = estimate_cost(1000, 500, "nous", "hermes-3-405b")
        assert isinstance(est, CostEstimate)
        assert (est.provider, est.model) == ("nous", "hermes-3-405b")

if __name__ == "__main__":
    pytest.main([__file__])

View File

@@ -1,74 +0,0 @@
"""Tests for tool call fallback handler."""
import pytest
import sys
from pathlib import Path
from unittest.mock import MagicMock, patch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from agent.tool_fallback import ToolFallbackHandler, FallbackEvent, get_tool_fallback_handler
class TestShouldFallback:
    """Which error strings should (and should not) trigger fallback."""

    def test_rate_limit_triggers(self):
        h = ToolFallbackHandler()
        for msg in ("429 rate limit exceeded", "RateLimitError: too many requests"):
            assert h.should_fallback(msg)

    def test_timeout_triggers(self):
        h = ToolFallbackHandler()
        for msg in ("Connection timed out", "Request timed out after 30s"):
            assert h.should_fallback(msg)

    def test_503_triggers(self):
        h = ToolFallbackHandler()
        for msg in ("503 Service Unavailable", "Service unavailable"):
            assert h.should_fallback(msg)

    def test_quota_triggers(self):
        h = ToolFallbackHandler()
        for msg in ("quota exceeded", "insufficient credits"):
            assert h.should_fallback(msg)

    def test_normal_error_no_trigger(self):
        h = ToolFallbackHandler()
        for msg in ("Invalid API key", "Model not found"):
            assert not h.should_fallback(msg)

    def test_disabled_handler(self):
        # A disabled handler never falls back, even on a trigger pattern.
        assert not ToolFallbackHandler(enabled=False).should_fallback("429 rate limit")
class TestFallbackEvents:
    """FallbackEvent dataclass construction."""

    def test_event_creation(self):
        evt = FallbackEvent(
            timestamp=1234567890.0,
            tool_name="terminal",
            original_provider="openrouter",
            fallback_provider="ollama",
            error="",
            success=True,
            duration_ms=150,
        )
        assert (evt.tool_name, evt.success, evt.duration_ms) == ("terminal", True, 150)
class TestFallbackHandler:
    """Handler stats, reporting, and module-level singleton."""

    def test_stats_initial(self):
        stats = ToolFallbackHandler().stats
        assert stats["total_fallbacks"] == 0
        assert stats["successful_fallbacks"] == 0

    def test_report_no_events(self):
        assert "No fallback events" in ToolFallbackHandler().format_report()

    def test_singleton(self):
        # Repeated calls must hand back the very same instance.
        assert get_tool_fallback_handler() is get_tool_fallback_handler()

192
tools/cost_estimator.py Normal file
View File

@@ -0,0 +1,192 @@
"""
Provider Cost Estimator — Estimate API costs from token counts.
Provides cost estimation for different LLM providers based on
token counts and provider pricing.
"""
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
@dataclass
class CostEstimate:
    """Cost estimate for a request."""
    input_tokens: int       # tokens sent to the model
    output_tokens: int      # tokens generated by the model
    input_cost_usd: float   # cost attributable to input tokens, USD
    output_cost_usd: float  # cost attributable to output tokens, USD
    total_cost_usd: float   # input_cost_usd + output_cost_usd
    provider: str           # provider name as passed by the caller
    model: str              # model name as passed by the caller
# Pricing table (USD per 1M tokens) — as of April 2026
PRICING = {
    "openrouter": {
        "claude-opus-4": {"input": 15.0, "output": 75.0},
        "claude-sonnet-4": {"input": 3.0, "output": 15.0},
        "claude-haiku-3.5": {"input": 0.80, "output": 4.0},
        "gpt-4o": {"input": 2.50, "output": 10.0},
        "gpt-4o-mini": {"input": 0.15, "output": 0.60},
        "gemini-2.5-pro": {"input": 1.25, "output": 10.0},
        "gemini-2.5-flash": {"input": 0.15, "output": 0.60},
        "llama-4-scout": {"input": 0.20, "output": 0.80},
        "llama-4-maverick": {"input": 0.50, "output": 2.0},
        "default": {"input": 1.0, "output": 3.0},
    },
    "nous": {
        "hermes-3-405b": {"input": 5.0, "output": 5.0},
        "mixtral-8x22b": {"input": 2.0, "output": 2.0},
        "hermes-2-mixtral-8x7b": {"input": 0.90, "output": 0.90},
        "default": {"input": 2.0, "output": 2.0},
    },
    "anthropic": {
        "claude-opus-4": {"input": 15.0, "output": 75.0},
        "claude-sonnet-4": {"input": 3.0, "output": 15.0},
        "claude-haiku-3.5": {"input": 0.80, "output": 4.0},
        "default": {"input": 3.0, "output": 15.0},
    },
    "local": {
        # Local models are free (electricity only)
        "default": {"input": 0.0, "output": 0.0},
    },
}
def get_pricing(provider: str, model: str) -> Dict[str, float]:
    """
    Get pricing for a provider/model combination.

    Lookup order: exact model match, then longest substring match in
    either direction (so "claude-sonnet-4-20250514" resolves to
    "claude-sonnet-4"), then the provider's "default" entry. The
    "default" sentinel key is never matched as a model name, and ties
    are broken by key length rather than dict insertion order — the
    original first-hit loop could return an arbitrary short match.
    Unknown providers fall back to the OpenRouter table.

    Args:
        provider: Provider name (openrouter, nous, anthropic, local)
        model: Model name

    Returns:
        Dict with 'input' and 'output' prices per 1M tokens
    """
    provider = provider.lower().strip()
    model = model.lower().strip()
    provider_pricing = PRICING.get(provider, PRICING["openrouter"])
    # Try exact match first
    if model in provider_pricing:
        return provider_pricing[model]
    # Partial match: most specific (longest) key wins; skip "default"
    candidates = [
        key for key in provider_pricing
        if key != "default" and (key in model or model in key)
    ]
    if candidates:
        return provider_pricing[max(candidates, key=len)]
    # Default
    return provider_pricing.get("default", {"input": 1.0, "output": 3.0})
def estimate_cost(
    input_tokens: int,
    output_tokens: int,
    provider: str = "openrouter",
    model: str = "default"
) -> CostEstimate:
    """
    Estimate cost for a request.

    Looks up the provider/model rates (USD per 1M tokens) and computes
    a per-direction cost breakdown plus the total.

    Args:
        input_tokens: Number of input tokens
        output_tokens: Number of output tokens
        provider: Provider name
        model: Model name

    Returns:
        CostEstimate with breakdown
    """
    rates = get_pricing(provider, model)
    # Rates are quoted per 1M tokens, so scale token counts down first.
    cost_in = (input_tokens / 1_000_000) * rates["input"]
    cost_out = (output_tokens / 1_000_000) * rates["output"]
    return CostEstimate(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        input_cost_usd=cost_in,
        output_cost_usd=cost_out,
        total_cost_usd=cost_in + cost_out,
        provider=provider,
        model=model,
    )
def estimate_session_cost(messages: list, provider: str = "openrouter", model: str = "default") -> CostEstimate:
    """
    Estimate cost for a session based on message content length.

    Uses a rough heuristic of ~4 characters per token. Messages with
    role "user" or "system" count toward input tokens — system prompts
    are billed as input by the supported providers, and the original
    version silently ignored them, understating cost. "assistant"
    messages count toward output tokens; other roles and non-string
    content (e.g. structured tool payloads) are skipped.

    Args:
        messages: List of messages (each with 'role' and 'content')
        provider: Provider name
        model: Model name

    Returns:
        CostEstimate for the session
    """
    # Rough token estimation: ~4 chars per token
    input_tokens = 0
    output_tokens = 0
    for msg in messages:
        content = msg.get("content", "")
        if not isinstance(content, str):
            continue
        tokens = len(content) // 4
        role = msg.get("role")
        if role in ("user", "system"):
            input_tokens += tokens
        elif role == "assistant":
            output_tokens += tokens
    return estimate_cost(input_tokens, output_tokens, provider, model)
def format_cost_report(estimates: list) -> str:
    """
    Format a list of cost estimates as a Markdown report.

    Args:
        estimates: List of CostEstimate objects

    Returns:
        Formatted report string with totals and a per-estimate table
    """
    grand_total = sum(e.total_cost_usd for e in estimates)
    in_total = sum(e.input_tokens for e in estimates)
    out_total = sum(e.output_tokens for e in estimates)
    header = [
        "# Cost Report",
        "",
        f"**Total Cost:** ${grand_total:.4f}",
        f"**Total Tokens:** {in_total + out_total:,} (input: {in_total:,}, output: {out_total:,})",
        "",
        "| Provider | Model | Input Tokens | Output Tokens | Cost |",
        "|----------|-------|--------------|---------------|------|",
    ]
    rows = [
        f"| {e.provider} | {e.model} | {e.input_tokens:,} | {e.output_tokens:,} | ${e.total_cost_usd:.4f} |"
        for e in estimates
    ]
    footer = ["", "*Generated by cost_estimator.py*"]
    return "\n".join(header + rows + footer)
def get_supported_providers() -> list:
    """Return the provider names present in the pricing table."""
    return [name for name in PRICING]
def get_provider_models(provider: str) -> list:
    """Return the model names priced for *provider* (excluding "default")."""
    table = PRICING.get(provider.lower().strip(), {})
    return [name for name in table if name != "default"]