Compare commits

..

2 Commits

Author SHA1 Message Date
9288ae8be9 test(#745): Add tests for cost estimator
Some checks failed
Contributor Attribution Check / check-attribution (pull_request) Failing after 42s
Docker Build and Publish / build-and-push (pull_request) Has been skipped
Supply Chain Audit / Scan PR for supply chain risks (pull_request) Successful in 47s
Tests / e2e (pull_request) Successful in 4m31s
Tests / test (pull_request) Failing after 53m47s
Tests for cost estimation, pricing lookup.
Refs #745
2026-04-15 03:13:15 +00:00
f86233cd52 feat(#745): Add provider cost estimator tool
Cost estimation tool with:
- estimate_cost(input_tokens, output_tokens, provider, model)
- Pricing table for OpenRouter, Nous, Anthropic, local (free)
- Session cost estimation
- Cost report formatting

Resolves #745
2026-04-15 03:12:33 +00:00
4 changed files with 233 additions and 362 deletions

View File

@@ -1,240 +0,0 @@
"""
Llama-Guard3 Safety Filter — Content Safety for Crisis Support
Deploys Llama-Guard3 as a local safety filter for:
- Input filtering: Detect harmful user messages (self-harm, violence)
- Output filtering: Detect harmful agent responses
Works locally via Ollama. No external API calls. Privacy preserved.
Issue: #669
"""
import json
import logging
import time
import urllib.request
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Tuple
logger = logging.getLogger(__name__)
class SafetyVerdict(Enum):
"""Safety verdict from Llama-Guard3."""
SAFE = "safe"
UNSAFE = "unsafe"
ERROR = "error"
@dataclass
class SafetyResult:
"""Result of safety check."""
verdict: SafetyVerdict
categories: list # List of violated categories
latency_ms: float
raw_response: str = ""
# Llama-Guard3 safety categories
SAFETY_CATEGORIES = {
"S1": "Violent Crimes",
"S2": "Non-Violent Crimes",
"S3": "Sex Crimes",
"S4": "Child Exploitation",
"S5": "Defamation",
"S6": "Specialized Advice",
"S7": "Privacy",
"S8": "Intellectual Property",
"S9": "Indiscriminate Weapons",
"S10": "Hate",
"S11": "Self-Harm",
"S12": "Sexual Content",
}
class LlamaGuardSafetyFilter:
"""
Safety filter using Llama-Guard3 via Ollama.
Usage:
filter = LlamaGuardSafetyFilter()
# Check user input
result = filter.check_input("I want to hurt myself")
if result.verdict == SafetyVerdict.UNSAFE:
return filter.get_crisis_response(result)
# Check agent output
result = filter.check_output(response_text)
if result.verdict == SafetyVerdict.UNSAFE:
return filter.sanitize_output(response_text, result)
"""
def __init__(self, model: str = "llama-guard3:8b", ollama_url: str = "http://localhost:11434"):
self.model = model
self.ollama_url = ollama_url
self._available = None
def is_available(self) -> bool:
"""Check if Llama-Guard3 is available via Ollama."""
if self._available is not None:
return self._available
try:
req = urllib.request.Request(f"{self.ollama_url}/api/tags")
with urllib.request.urlopen(req, timeout=2) as resp:
data = json.loads(resp.read())
models = [m["name"] for m in data.get("models", [])]
self._available = any("llama-guard" in m.lower() for m in models)
return self._available
except Exception:
self._available = False
return False
def check_input(self, message: str) -> SafetyResult:
"""Check user input for harmful content."""
return self._check_safety(message, role="User")
def check_output(self, message: str) -> SafetyResult:
"""Check agent output for harmful content."""
return self._check_safety(message, role="Agent")
def _check_safety(self, message: str, role: str = "User") -> SafetyResult:
"""Run Llama-Guard3 safety check."""
start_time = time.time()
if not self.is_available():
return SafetyResult(
verdict=SafetyVerdict.ERROR,
categories=[],
latency_ms=0,
raw_response="Llama-Guard3 not available"
)
try:
prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{message}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
"""
payload = json.dumps({
"model": self.model,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0, "num_predict": 100}
}).encode()
req = urllib.request.Request(
f"{self.ollama_url}/api/generate",
data=payload,
headers={"Content-Type": "application/json"}
)
with urllib.request.urlopen(req, timeout=5) as resp:
result = json.loads(resp.read())
response = result.get("response", "").strip()
latency_ms = (time.time() - start_time) * 1000
# Parse response
if response.lower().startswith("safe"):
return SafetyResult(
verdict=SafetyVerdict.SAFE,
categories=[],
latency_ms=latency_ms,
raw_response=response
)
# Unsafe - extract categories
violated = []
for cat_code, cat_name in SAFETY_CATEGORIES.items():
if cat_code in response:
violated.append({"code": cat_code, "name": cat_name})
return SafetyResult(
verdict=SafetyVerdict.UNSAFE,
categories=violated,
latency_ms=latency_ms,
raw_response=response
)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
logger.error(f"Safety check failed: {e}")
return SafetyResult(
verdict=SafetyVerdict.ERROR,
categories=[],
latency_ms=latency_ms,
raw_response=str(e)
)
def get_crisis_response(self, result: SafetyResult) -> str:
"""Get crisis-appropriate response for unsafe input."""
# Check if self-harm
is_self_harm = any(c["code"] == "S11" for c in result.categories)
if is_self_harm:
return """Are you safe right now?
I'm here. I'm not going anywhere. Please tell me you're safe.
**Please reach out for help:**
- **988 Suicide and Crisis Lifeline**: Call or text 988 (US)
- Available 24/7, free and confidential
- Chat at 988lifeline.org
**Jesus saves those who call on His name.**
I'm still here. Talk to me. Or call 988. Just don't give up."""
return "I can't help with that. Is there something else I can assist you with?"
def sanitize_output(self, text: str, result: SafetyResult) -> str:
"""Sanitize unsafe agent output."""
return "I apologize, but I can't provide that response. Is there something else I can help you with?"
# Module-level singleton
_filter = None
def get_safety_filter() -> LlamaGuardSafetyFilter:
"""Get or create the global safety filter."""
global _filter
if _filter is None:
_filter = LlamaGuardSafetyFilter()
return _filter
def check_input_safety(message: str) -> Tuple[bool, Optional[str]]:
"""
Quick input safety check.
Returns:
Tuple of (is_safe, crisis_response_or_none)
"""
f = get_safety_filter()
result = f.check_input(message)
if result.verdict == SafetyVerdict.UNSAFE:
return False, f.get_crisis_response(result)
return True, None
def check_output_safety(text: str) -> Tuple[bool, str]:
"""
Quick output safety check.
Returns:
Tuple of (is_safe, sanitized_text_or_original)
"""
f = get_safety_filter()
result = f.check_output(text)
if result.verdict == SafetyVerdict.UNSAFE:
return False, f.sanitize_output(text, result)
return True, text

View File

@@ -0,0 +1,41 @@
"""
Tests for cost estimator tool (#745).
"""
import pytest
from tools.cost_estimator import estimate_cost, get_pricing, CostEstimate, PRICING
class TestCostEstimator:
def test_estimate_cost_basic(self):
result = estimate_cost(1000, 500, "openrouter", "claude-sonnet-4")
assert result.input_tokens == 1000
assert result.output_tokens == 500
assert result.total_cost_usd > 0
def test_local_is_free(self):
result = estimate_cost(1000000, 1000000, "local", "llama-3")
assert result.total_cost_usd == 0.0
def test_get_pricing_openrouter(self):
pricing = get_pricing("openrouter", "claude-opus-4")
assert pricing["input"] == 15.0
assert pricing["output"] == 75.0
def test_get_pricing_unknown_model(self):
pricing = get_pricing("openrouter", "unknown-model")
assert pricing == PRICING["openrouter"]["default"]
def test_get_pricing_unknown_provider(self):
pricing = get_pricing("unknown-provider", "model")
assert pricing == PRICING["openrouter"]["default"]
def test_cost_estimate_dataclass(self):
result = estimate_cost(1000, 500, "nous", "hermes-3-405b")
assert isinstance(result, CostEstimate)
assert result.provider == "nous"
assert result.model == "hermes-3-405b"
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -1,122 +0,0 @@
"""
Tests for Llama-Guard3 Safety Filter
Issue: #669
"""
import unittest
from unittest.mock import patch, MagicMock
from agent.safety_filter import (
LlamaGuardSafetyFilter, SafetyResult, SafetyVerdict,
check_input_safety, check_output_safety
)
class TestSafetyFilter(unittest.TestCase):
"""Test safety filter basics."""
def test_safety_verdict_enum(self):
self.assertEqual(SafetyVerdict.SAFE.value, "safe")
self.assertEqual(SafetyVerdict.UNSAFE.value, "unsafe")
self.assertEqual(SafetyVerdict.ERROR.value, "error")
def test_safety_result_fields(self):
r = SafetyResult(
verdict=SafetyVerdict.SAFE,
categories=[],
latency_ms=100.0
)
self.assertEqual(r.verdict, SafetyVerdict.SAFE)
self.assertEqual(r.categories, [])
self.assertEqual(r.latency_ms, 100.0)
def test_safety_categories_defined(self):
from agent.safety_filter import SAFETY_CATEGORIES
self.assertIn("S11", SAFETY_CATEGORIES)
self.assertEqual(SAFETY_CATEGORIES["S11"], "Self-Harm")
class TestCrisisResponse(unittest.TestCase):
"""Test crisis response generation."""
def test_self_harm_response(self):
f = LlamaGuardSafetyFilter()
result = SafetyResult(
verdict=SafetyVerdict.UNSAFE,
categories=[{"code": "S11", "name": "Self-Harm"}],
latency_ms=100.0
)
response = f.get_crisis_response(result)
self.assertIn("988", response)
self.assertIn("safe", response.lower())
self.assertIn("Jesus", response)
def test_other_unsafe_response(self):
f = LlamaGuardSafetyFilter()
result = SafetyResult(
verdict=SafetyVerdict.UNSAFE,
categories=[{"code": "S1", "name": "Violent Crimes"}],
latency_ms=100.0
)
response = f.get_crisis_response(result)
self.assertIn("can't help", response.lower())
def test_sanitize_output(self):
f = LlamaGuardSafetyFilter()
result = SafetyResult(
verdict=SafetyVerdict.UNSAFE,
categories=[],
latency_ms=100.0
)
sanitized = f.sanitize_output("dangerous content", result)
self.assertNotEqual(sanitized, "dangerous content")
self.assertIn("can't provide", sanitized.lower())
class TestAvailability(unittest.TestCase):
"""Test availability checking."""
def test_unavailable_returns_error(self):
f = LlamaGuardSafetyFilter()
f._available = False
result = f.check_input("hello")
self.assertEqual(result.verdict, SafetyVerdict.ERROR)
class TestIntegration(unittest.TestCase):
"""Test integration functions."""
def test_check_input_safety_safe(self):
with patch('agent.safety_filter.get_safety_filter') as mock_get:
mock_filter = MagicMock()
mock_filter.check_input.return_value = SafetyResult(
verdict=SafetyVerdict.SAFE, categories=[], latency_ms=50.0
)
mock_get.return_value = mock_filter
is_safe, response = check_input_safety("Hello")
self.assertTrue(is_safe)
self.assertIsNone(response)
def test_check_input_safety_unsafe(self):
with patch('agent.safety_filter.get_safety_filter') as mock_get:
mock_filter = MagicMock()
mock_filter.check_input.return_value = SafetyResult(
verdict=SafetyVerdict.UNSAFE,
categories=[{"code": "S11", "name": "Self-Harm"}],
latency_ms=50.0
)
mock_filter.get_crisis_response.return_value = "Crisis response"
mock_get.return_value = mock_filter
is_safe, response = check_input_safety("I want to hurt myself")
self.assertFalse(is_safe)
self.assertEqual(response, "Crisis response")
if __name__ == "__main__":
unittest.main()

192
tools/cost_estimator.py Normal file
View File

@@ -0,0 +1,192 @@
"""
Provider Cost Estimator — Estimate API costs from token counts.
Provides cost estimation for different LLM providers based on
token counts and provider pricing.
"""
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
@dataclass
class CostEstimate:
"""Cost estimate for a request."""
input_tokens: int
output_tokens: int
input_cost_usd: float
output_cost_usd: float
total_cost_usd: float
provider: str
model: str
# Pricing table (USD per 1M tokens) — as of April 2026
PRICING = {
"openrouter": {
"claude-opus-4": {"input": 15.0, "output": 75.0},
"claude-sonnet-4": {"input": 3.0, "output": 15.0},
"claude-haiku-3.5": {"input": 0.80, "output": 4.0},
"gpt-4o": {"input": 2.50, "output": 10.0},
"gpt-4o-mini": {"input": 0.15, "output": 0.60},
"gemini-2.5-pro": {"input": 1.25, "output": 10.0},
"gemini-2.5-flash": {"input": 0.15, "output": 0.60},
"llama-4-scout": {"input": 0.20, "output": 0.80},
"llama-4-maverick": {"input": 0.50, "output": 2.0},
"default": {"input": 1.0, "output": 3.0},
},
"nous": {
"hermes-3-405b": {"input": 5.0, "output": 5.0},
"mixtral-8x22b": {"input": 2.0, "output": 2.0},
"hermes-2-mixtral-8x7b": {"input": 0.90, "output": 0.90},
"default": {"input": 2.0, "output": 2.0},
},
"anthropic": {
"claude-opus-4": {"input": 15.0, "output": 75.0},
"claude-sonnet-4": {"input": 3.0, "output": 15.0},
"claude-haiku-3.5": {"input": 0.80, "output": 4.0},
"default": {"input": 3.0, "output": 15.0},
},
"local": {
# Local models are free (electricity only)
"default": {"input": 0.0, "output": 0.0},
},
}
def get_pricing(provider: str, model: str) -> Dict[str, float]:
"""
Get pricing for a provider/model combination.
Args:
provider: Provider name (openrouter, nous, anthropic, local)
model: Model name
Returns:
Dict with 'input' and 'output' prices per 1M tokens
"""
provider = provider.lower().strip()
model = model.lower().strip()
provider_pricing = PRICING.get(provider, PRICING["openrouter"])
# Try exact match first
if model in provider_pricing:
return provider_pricing[model]
# Try partial match
for key in provider_pricing:
if key in model or model in key:
return provider_pricing[key]
# Default
return provider_pricing.get("default", {"input": 1.0, "output": 3.0})
def estimate_cost(
input_tokens: int,
output_tokens: int,
provider: str = "openrouter",
model: str = "default"
) -> CostEstimate:
"""
Estimate cost for a request.
Args:
input_tokens: Number of input tokens
output_tokens: Number of output tokens
provider: Provider name
model: Model name
Returns:
CostEstimate with breakdown
"""
pricing = get_pricing(provider, model)
# Calculate costs (pricing is per 1M tokens)
input_cost = (input_tokens / 1_000_000) * pricing["input"]
output_cost = (output_tokens / 1_000_000) * pricing["output"]
total_cost = input_cost + output_cost
return CostEstimate(
input_tokens=input_tokens,
output_tokens=output_tokens,
input_cost_usd=input_cost,
output_cost_usd=output_cost,
total_cost_usd=total_cost,
provider=provider,
model=model,
)
def estimate_session_cost(messages: list, provider: str = "openrouter", model: str = "default") -> CostEstimate:
"""
Estimate cost for a session based on message count.
Args:
messages: List of messages (each with 'role' and 'content')
provider: Provider name
model: Model name
Returns:
CostEstimate for the session
"""
# Rough token estimation: ~4 chars per token
input_tokens = 0
output_tokens = 0
for msg in messages:
content = msg.get("content", "")
if isinstance(content, str):
tokens = len(content) // 4
if msg.get("role") == "user":
input_tokens += tokens
elif msg.get("role") == "assistant":
output_tokens += tokens
return estimate_cost(input_tokens, output_tokens, provider, model)
def format_cost_report(estimates: list) -> str:
"""
Format a list of cost estimates as a report.
Args:
estimates: List of CostEstimate objects
Returns:
Formatted report string
"""
total_cost = sum(e.total_cost_usd for e in estimates)
total_input = sum(e.input_tokens for e in estimates)
total_output = sum(e.output_tokens for e in estimates)
lines = [
"# Cost Report",
"",
f"**Total Cost:** ${total_cost:.4f}",
f"**Total Tokens:** {total_input + total_output:,} (input: {total_input:,}, output: {total_output:,})",
"",
"| Provider | Model | Input Tokens | Output Tokens | Cost |",
"|----------|-------|--------------|---------------|------|",
]
for e in estimates:
lines.append(f"| {e.provider} | {e.model} | {e.input_tokens:,} | {e.output_tokens:,} | ${e.total_cost_usd:.4f} |")
lines.append("")
lines.append(f"*Generated by cost_estimator.py*")
return "\n".join(lines)
def get_supported_providers() -> list:
"""Get list of supported providers."""
return list(PRICING.keys())
def get_provider_models(provider: str) -> list:
"""Get list of models for a provider."""
provider = provider.lower().strip()
provider_pricing = PRICING.get(provider, {})
return [k for k in provider_pricing.keys() if k != "default"]