Compare commits

..

2 Commits

Author SHA1 Message Date
9288ae8be9 test(#745): Add tests for cost estimator
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 42s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 47s
Tests / e2e (pull_request) Successful in 4m31s
Tests / test (pull_request) Failing after 53m47s
Tests for cost estimation, pricing lookup.
Refs #745
2026-04-15 03:13:15 +00:00
f86233cd52 feat(#745): Add provider cost estimator tool
Cost estimation tool with:
- estimate_cost(input_tokens, output_tokens, provider, model)
- Pricing table for OpenRouter, Nous, Anthropic, local (free)
- Session cost estimation
- Cost report formatting

Resolves #745
2026-04-15 03:12:33 +00:00
4 changed files with 233 additions and 383 deletions

View File

@@ -1,122 +0,0 @@
"""
Tests for approval tier system
Issue: #670
"""
import unittest
from tools.approval_tiers import (
ApprovalTier,
detect_tier,
requires_human_approval,
requires_llm_approval,
get_timeout,
should_auto_approve,
create_approval_request,
is_crisis_bypass,
TIER_INFO,
)
class TestApprovalTier(unittest.TestCase):
def test_tier_values(self):
self.assertEqual(ApprovalTier.SAFE, 0)
self.assertEqual(ApprovalTier.LOW, 1)
self.assertEqual(ApprovalTier.MEDIUM, 2)
self.assertEqual(ApprovalTier.HIGH, 3)
self.assertEqual(ApprovalTier.CRITICAL, 4)
class TestTierDetection(unittest.TestCase):
def test_safe_actions(self):
self.assertEqual(detect_tier("read_file"), ApprovalTier.SAFE)
self.assertEqual(detect_tier("web_search"), ApprovalTier.SAFE)
self.assertEqual(detect_tier("session_search"), ApprovalTier.SAFE)
def test_low_actions(self):
self.assertEqual(detect_tier("write_file"), ApprovalTier.LOW)
self.assertEqual(detect_tier("terminal"), ApprovalTier.LOW)
self.assertEqual(detect_tier("execute_code"), ApprovalTier.LOW)
def test_medium_actions(self):
self.assertEqual(detect_tier("send_message"), ApprovalTier.MEDIUM)
self.assertEqual(detect_tier("git_push"), ApprovalTier.MEDIUM)
def test_high_actions(self):
self.assertEqual(detect_tier("config_change"), ApprovalTier.HIGH)
self.assertEqual(detect_tier("key_rotation"), ApprovalTier.HIGH)
def test_critical_actions(self):
self.assertEqual(detect_tier("kill_process"), ApprovalTier.CRITICAL)
self.assertEqual(detect_tier("shutdown"), ApprovalTier.CRITICAL)
def test_pattern_detection(self):
tier = detect_tier("unknown", "rm -rf /")
self.assertEqual(tier, ApprovalTier.CRITICAL)
tier = detect_tier("unknown", "sudo apt install")
self.assertEqual(tier, ApprovalTier.MEDIUM)
class TestTierInfo(unittest.TestCase):
def test_safe_no_approval(self):
self.assertFalse(requires_human_approval(ApprovalTier.SAFE))
self.assertFalse(requires_llm_approval(ApprovalTier.SAFE))
self.assertIsNone(get_timeout(ApprovalTier.SAFE))
def test_medium_requires_both(self):
self.assertTrue(requires_human_approval(ApprovalTier.MEDIUM))
self.assertTrue(requires_llm_approval(ApprovalTier.MEDIUM))
self.assertEqual(get_timeout(ApprovalTier.MEDIUM), 60)
def test_critical_fast_timeout(self):
self.assertEqual(get_timeout(ApprovalTier.CRITICAL), 10)
class TestAutoApprove(unittest.TestCase):
def test_safe_auto_approves(self):
self.assertTrue(should_auto_approve("read_file"))
self.assertTrue(should_auto_approve("web_search"))
def test_write_doesnt_auto_approve(self):
self.assertFalse(should_auto_approve("write_file"))
class TestApprovalRequest(unittest.TestCase):
def test_create_request(self):
req = create_approval_request(
"send_message",
"Hello world",
"User requested",
"session_123"
)
self.assertEqual(req.tier, ApprovalTier.MEDIUM)
self.assertEqual(req.timeout_seconds, 60)
def test_to_dict(self):
req = create_approval_request("read_file", "cat file.txt", "test", "s1")
d = req.to_dict()
self.assertEqual(d["tier"], 0)
self.assertEqual(d["tier_name"], "Safe")
class TestCrisisBypass(unittest.TestCase):
def test_send_message_bypass(self):
self.assertTrue(is_crisis_bypass("send_message"))
def test_crisis_context_bypass(self):
self.assertTrue(is_crisis_bypass("unknown", "call 988 lifeline"))
self.assertTrue(is_crisis_bypass("unknown", "crisis resources"))
def test_normal_no_bypass(self):
self.assertFalse(is_crisis_bypass("read_file"))
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,41 @@
"""
Tests for cost estimator tool (#745).
"""
import pytest
from tools.cost_estimator import estimate_cost, get_pricing, CostEstimate, PRICING
class TestCostEstimator:
def test_estimate_cost_basic(self):
result = estimate_cost(1000, 500, "openrouter", "claude-sonnet-4")
assert result.input_tokens == 1000
assert result.output_tokens == 500
assert result.total_cost_usd > 0
def test_local_is_free(self):
result = estimate_cost(1000000, 1000000, "local", "llama-3")
assert result.total_cost_usd == 0.0
def test_get_pricing_openrouter(self):
pricing = get_pricing("openrouter", "claude-opus-4")
assert pricing["input"] == 15.0
assert pricing["output"] == 75.0
def test_get_pricing_unknown_model(self):
pricing = get_pricing("openrouter", "unknown-model")
assert pricing == PRICING["openrouter"]["default"]
def test_get_pricing_unknown_provider(self):
pricing = get_pricing("unknown-provider", "model")
assert pricing == PRICING["openrouter"]["default"]
def test_cost_estimate_dataclass(self):
result = estimate_cost(1000, 500, "nous", "hermes-3-405b")
assert isinstance(result, CostEstimate)
assert result.provider == "nous"
assert result.model == "hermes-3-405b"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,261 +0,0 @@
"""
Approval Tier System — Graduated safety based on risk level
Extends approval.py with 5-tier system for command approval.
| Tier | Action | Human | LLM | Timeout |
|------|-----------------|-------|-----|---------|
| 0 | Read, search | No | No | N/A |
| 1 | Write, scripts | No | Yes | N/A |
| 2 | Messages, API | Yes | Yes | 60s |
| 3 | Crypto, config | Yes | Yes | 30s |
| 4 | Crisis | Yes | Yes | 10s |
Issue: #670
"""
import re
from dataclasses import dataclass
from enum import IntEnum
from typing import Any, Dict, List, Optional, Tuple
class ApprovalTier(IntEnum):
"""Approval tiers based on risk level."""
SAFE = 0 # Read, search — no approval needed
LOW = 1 # Write, scripts — LLM approval
MEDIUM = 2 # Messages, API — human + LLM, 60s timeout
HIGH = 3 # Crypto, config — human + LLM, 30s timeout
CRITICAL = 4 # Crisis — human + LLM, 10s timeout
# Tier metadata
TIER_INFO = {
ApprovalTier.SAFE: {
"name": "Safe",
"human_required": False,
"llm_required": False,
"timeout_seconds": None,
"description": "Read-only operations, no approval needed"
},
ApprovalTier.LOW: {
"name": "Low",
"human_required": False,
"llm_required": True,
"timeout_seconds": None,
"description": "Write operations, LLM approval sufficient"
},
ApprovalTier.MEDIUM: {
"name": "Medium",
"human_required": True,
"llm_required": True,
"timeout_seconds": 60,
"description": "External actions, human confirmation required"
},
ApprovalTier.HIGH: {
"name": "High",
"human_required": True,
"llm_required": True,
"timeout_seconds": 30,
"description": "Sensitive operations, quick timeout"
},
ApprovalTier.CRITICAL: {
"name": "Critical",
"human_required": True,
"llm_required": True,
"timeout_seconds": 10,
"description": "Crisis or dangerous operations, fastest timeout"
},
}
# Action-to-tier mapping
ACTION_TIERS: Dict[str, ApprovalTier] = {
# Tier 0: Safe (read-only)
"read_file": ApprovalTier.SAFE,
"search_files": ApprovalTier.SAFE,
"web_search": ApprovalTier.SAFE,
"session_search": ApprovalTier.SAFE,
"list_files": ApprovalTier.SAFE,
"get_file_content": ApprovalTier.SAFE,
"memory_search": ApprovalTier.SAFE,
"skills_list": ApprovalTier.SAFE,
"skills_search": ApprovalTier.SAFE,
# Tier 1: Low (write operations)
"write_file": ApprovalTier.LOW,
"create_file": ApprovalTier.LOW,
"patch_file": ApprovalTier.LOW,
"delete_file": ApprovalTier.LOW,
"execute_code": ApprovalTier.LOW,
"terminal": ApprovalTier.LOW,
"run_script": ApprovalTier.LOW,
"skill_install": ApprovalTier.LOW,
# Tier 2: Medium (external actions)
"send_message": ApprovalTier.MEDIUM,
"web_fetch": ApprovalTier.MEDIUM,
"browser_navigate": ApprovalTier.MEDIUM,
"api_call": ApprovalTier.MEDIUM,
"gitea_create_issue": ApprovalTier.MEDIUM,
"gitea_create_pr": ApprovalTier.MEDIUM,
"git_push": ApprovalTier.MEDIUM,
"deploy": ApprovalTier.MEDIUM,
# Tier 3: High (sensitive operations)
"config_change": ApprovalTier.HIGH,
"env_change": ApprovalTier.HIGH,
"key_rotation": ApprovalTier.HIGH,
"access_grant": ApprovalTier.HIGH,
"permission_change": ApprovalTier.HIGH,
"backup_restore": ApprovalTier.HIGH,
# Tier 4: Critical (crisis/dangerous)
"kill_process": ApprovalTier.CRITICAL,
"rm_rf": ApprovalTier.CRITICAL,
"format_disk": ApprovalTier.CRITICAL,
"shutdown": ApprovalTier.CRITICAL,
"crisis_override": ApprovalTier.CRITICAL,
}
# Dangerous command patterns (from existing approval.py)
_DANGEROUS_PATTERNS = [
(r"rm\s+-rf\s+/", ApprovalTier.CRITICAL),
(r"mkfs\.", ApprovalTier.CRITICAL),
(r"dd\s+if=.*of=/dev/", ApprovalTier.CRITICAL),
(r"shutdown|reboot|halt", ApprovalTier.CRITICAL),
(r"chmod\s+777", ApprovalTier.HIGH),
(r"curl.*\|\s*bash", ApprovalTier.HIGH),
(r"wget.*\|\s*sh", ApprovalTier.HIGH),
(r"eval\s*\(", ApprovalTier.HIGH),
(r"sudo\s+", ApprovalTier.MEDIUM),
(r"git\s+push.*--force", ApprovalTier.HIGH),
(r"docker\s+rm.*-f", ApprovalTier.MEDIUM),
(r"kubectl\s+delete", ApprovalTier.HIGH),
]
@dataclass
class ApprovalRequest:
"""A request for approval."""
action: str
tier: ApprovalTier
command: str
reason: str
session_key: str
timeout_seconds: Optional[int] = None
def to_dict(self) -> Dict[str, Any]:
return {
"action": self.action,
"tier": self.tier.value,
"tier_name": TIER_INFO[self.tier]["name"],
"command": self.command,
"reason": self.reason,
"session_key": self.session_key,
"timeout": self.timeout_seconds,
"human_required": TIER_INFO[self.tier]["human_required"],
"llm_required": TIER_INFO[self.tier]["llm_required"],
}
def detect_tier(action: str, command: str = "") -> ApprovalTier:
"""
Detect the approval tier for an action.
Checks action name first, then falls back to pattern matching.
"""
# Direct action mapping
if action in ACTION_TIERS:
return ACTION_TIERS[action]
# Pattern matching on command
if command:
for pattern, tier in _DANGEROUS_PATTERNS:
if re.search(pattern, command, re.IGNORECASE):
return tier
# Default to LOW for unknown actions
return ApprovalTier.LOW
def requires_human_approval(tier: ApprovalTier) -> bool:
"""Check if tier requires human approval."""
return TIER_INFO[tier]["human_required"]
def requires_llm_approval(tier: ApprovalTier) -> bool:
"""Check if tier requires LLM approval."""
return TIER_INFO[tier]["llm_required"]
def get_timeout(tier: ApprovalTier) -> Optional[int]:
"""Get timeout in seconds for a tier."""
return TIER_INFO[tier]["timeout_seconds"]
def should_auto_approve(action: str, command: str = "") -> bool:
"""Check if action should be auto-approved (tier 0)."""
tier = detect_tier(action, command)
return tier == ApprovalTier.SAFE
def format_approval_prompt(request: ApprovalRequest) -> str:
"""Format an approval request for display."""
info = TIER_INFO[request.tier]
lines = []
lines.append(f"⚠️ Approval Required (Tier {request.tier.value}: {info['name']})")
lines.append(f"")
lines.append(f"Action: {request.action}")
lines.append(f"Command: {request.command[:100]}{'...' if len(request.command) > 100 else ''}")
lines.append(f"Reason: {request.reason}")
lines.append(f"")
if info["human_required"]:
lines.append(f"👤 Human approval required")
if info["llm_required"]:
lines.append(f"🤖 LLM approval required")
if info["timeout_seconds"]:
lines.append(f"⏱️ Timeout: {info['timeout_seconds']}s")
return "\n".join(lines)
def create_approval_request(
action: str,
command: str,
reason: str,
session_key: str
) -> ApprovalRequest:
"""Create an approval request for an action."""
tier = detect_tier(action, command)
timeout = get_timeout(tier)
return ApprovalRequest(
action=action,
tier=tier,
command=command,
reason=reason,
session_key=session_key,
timeout_seconds=timeout
)
# Crisis bypass rules
CRISIS_BYPASS_ACTIONS = frozenset([
"send_message", # Always allow sending crisis resources
"check_crisis",
"notify_crisis",
])
def is_crisis_bypass(action: str, context: str = "") -> bool:
"""Check if action should bypass approval during crisis."""
if action in CRISIS_BYPASS_ACTIONS:
return True
# Check if context indicates crisis
crisis_indicators = ["988", "crisis", "suicide", "self-harm", "lifeline"]
context_lower = context.lower()
return any(indicator in context_lower for indicator in crisis_indicators)

192
tools/cost_estimator.py Normal file
View File

@@ -0,0 +1,192 @@
"""
Provider Cost Estimator — Estimate API costs from token counts.
Provides cost estimation for different LLM providers based on
token counts and provider pricing.
"""
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
@dataclass
class CostEstimate:
"""Cost estimate for a request."""
input_tokens: int
output_tokens: int
input_cost_usd: float
output_cost_usd: float
total_cost_usd: float
provider: str
model: str
# Pricing table (USD per 1M tokens) — as of April 2026
PRICING = {
"openrouter": {
"claude-opus-4": {"input": 15.0, "output": 75.0},
"claude-sonnet-4": {"input": 3.0, "output": 15.0},
"claude-haiku-3.5": {"input": 0.80, "output": 4.0},
"gpt-4o": {"input": 2.50, "output": 10.0},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"gemini-2.5-pro": {"input": 1.25, "output": 10.0},
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
"llama-4-scout": {"input": 0.20, "output": 0.80},
"llama-4-maverick": {"input": 0.50, "output": 2.0},
"default": {"input": 1.0, "output": 3.0},
},
"nous": {
"hermes-3-405b": {"input": 5.0, "output": 5.0},
"mixtral-8x22b": {"input": 2.0, "output": 2.0},
"hermes-2-mixtral-8x7b": {"input": 0.90, "output": 0.90},
"default": {"input": 2.0, "output": 2.0},
},
"anthropic": {
"claude-opus-4": {"input": 15.0, "output": 75.0},
"claude-sonnet-4": {"input": 3.0, "output": 15.0},
"claude-haiku-3.5": {"input": 0.80, "output": 4.0},
"default": {"input": 3.0, "output": 15.0},
},
"local": {
# Local models are free (electricity only)
"default": {"input": 0.0, "output": 0.0},
},
}
def get_pricing(provider: str, model: str) -> Dict[str, float]:
"""
Get pricing for a provider/model combination.
Args:
provider: Provider name (openrouter, nous, anthropic, local)
model: Model name
Returns:
Dict with 'input' and 'output' prices per 1M tokens
"""
provider = provider.lower().strip()
model = model.lower().strip()
provider_pricing = PRICING.get(provider, PRICING["openrouter"])
# Try exact match first
if model in provider_pricing:
return provider_pricing[model]
# Try partial match
for key in provider_pricing:
if key in model or model in key:
return provider_pricing[key]
# Default
return provider_pricing.get("default", {"input": 1.0, "output": 3.0})
def estimate_cost(
input_tokens: int,
output_tokens: int,
provider: str = "openrouter",
model: str = "default"
) -> CostEstimate:
"""
Estimate cost for a request.
Args:
input_tokens: Number of input tokens
output_tokens: Number of output tokens
provider: Provider name
model: Model name
Returns:
CostEstimate with breakdown
"""
pricing = get_pricing(provider, model)
# Calculate costs (pricing is per 1M tokens)
input_cost = (input_tokens / 1_000_000) * pricing["input"]
output_cost = (output_tokens / 1_000_000) * pricing["output"]
total_cost = input_cost + output_cost
return CostEstimate(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_cost_usd=input_cost,
output_cost_usd=output_cost,
total_cost_usd=total_cost,
provider=provider,
model=model,
)
def estimate_session_cost(messages: list, provider: str = "openrouter", model: str = "default") -> CostEstimate:
"""
Estimate cost for a session based on message count.
Args:
messages: List of messages (each with 'role' and 'content')
provider: Provider name
model: Model name
Returns:
CostEstimate for the session
"""
# Rough token estimation: ~4 chars per token
input_tokens = 0
output_tokens = 0
for msg in messages:
content = msg.get("content", "")
if isinstance(content, str):
tokens = len(content) // 4
if msg.get("role") == "user":
input_tokens += tokens
elif msg.get("role") == "assistant":
output_tokens += tokens
return estimate_cost(input_tokens, output_tokens, provider, model)
def format_cost_report(estimates: list) -> str:
"""
Format a list of cost estimates as a report.
Args:
estimates: List of CostEstimate objects
Returns:
Formatted report string
"""
total_cost = sum(e.total_cost_usd for e in estimates)
total_input = sum(e.input_tokens for e in estimates)
total_output = sum(e.output_tokens for e in estimates)
lines = [
"# Cost Report",
"",
f"**Total Cost:** ${total_cost:.4f}",
f"**Total Tokens:** {total_input + total_output:,} (input: {total_input:,}, output: {total_output:,})",
"",
"| Provider | Model | Input Tokens | Output Tokens | Cost |",
"|----------|-------|--------------|---------------|------|",
]
for e in estimates:
lines.append(f"| {e.provider} | {e.model} | {e.input_tokens:,} | {e.output_tokens:,} | ${e.total_cost_usd:.4f} |")
lines.append("")
lines.append(f"*Generated by cost_estimator.py*")
return "\n".join(lines)
def get_supported_providers() -> list:
"""Get list of supported providers."""
return list(PRICING.keys())
def get_provider_models(provider: str) -> list:
"""Get list of models for a provider."""
provider = provider.lower().strip()
provider_pricing = PRICING.get(provider, {})
return [k for k in provider_pricing.keys() if k != "default"]