Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8141bf8ba3 | ||
|
|
892c4ab70a |
@@ -1396,8 +1396,6 @@ def normalize_anthropic_response(
|
||||
"tool_use": "tool_calls",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"refusal": "content_filter",
|
||||
"model_context_window_exceeded": "length",
|
||||
}
|
||||
finish_reason = stop_reason_map.get(response.stop_reason, "stop")
|
||||
|
||||
@@ -1411,42 +1409,3 @@ def normalize_anthropic_response(
|
||||
),
|
||||
finish_reason,
|
||||
)
|
||||
|
||||
|
||||
def normalize_anthropic_response_v2(
|
||||
response,
|
||||
strip_tool_prefix: bool = False,
|
||||
) -> "NormalizedResponse":
|
||||
"""Normalize Anthropic response to NormalizedResponse.
|
||||
|
||||
Wraps the existing normalize_anthropic_response() and maps its output
|
||||
to the shared transport types. This allows incremental migration
|
||||
without disturbing the legacy call sites.
|
||||
"""
|
||||
from agent.transports.types import NormalizedResponse, build_tool_call
|
||||
|
||||
assistant_msg, finish_reason = normalize_anthropic_response(response, strip_tool_prefix)
|
||||
|
||||
tool_calls = None
|
||||
if assistant_msg.tool_calls:
|
||||
tool_calls = [
|
||||
build_tool_call(
|
||||
id=tc.id,
|
||||
name=tc.function.name,
|
||||
arguments=tc.function.arguments,
|
||||
)
|
||||
for tc in assistant_msg.tool_calls
|
||||
]
|
||||
|
||||
provider_data = {}
|
||||
if getattr(assistant_msg, "reasoning_details", None):
|
||||
provider_data["reasoning_details"] = assistant_msg.reasoning_details
|
||||
|
||||
return NormalizedResponse(
|
||||
content=assistant_msg.content,
|
||||
tool_calls=tool_calls,
|
||||
finish_reason=finish_reason,
|
||||
reasoning=getattr(assistant_msg, "reasoning", None),
|
||||
usage=None,
|
||||
provider_data=provider_data or None,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from agent.telemetry_logger import log_token_usage\n"""Shared auxiliary client router for side tasks.
|
||||
"""Shared auxiliary client router for side tasks.
|
||||
|
||||
Provides a single resolution chain so every consumer (context compression,
|
||||
session search, web extraction, vision analysis, browser vision) picks up
|
||||
@@ -38,6 +38,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from agent.telemetry_logger import log_token_usage
|
||||
import time
|
||||
from pathlib import Path # noqa: F401 — used by test mocks
|
||||
from types import SimpleNamespace
|
||||
@@ -122,6 +123,16 @@ _OR_HEADERS = {
|
||||
"X-OpenRouter-Categories": "productivity,cli-agent",
|
||||
}
|
||||
|
||||
# Vercel AI Gateway app attribution headers. HTTP-Referer maps to
|
||||
# referrerUrl and X-Title maps to appName in the gateway analytics.
|
||||
from hermes_cli import __version__ as _HERMES_VERSION
|
||||
|
||||
_AI_GATEWAY_HEADERS = {
|
||||
"HTTP-Referer": "https://hermes-agent.nousresearch.com",
|
||||
"X-Title": "Hermes Agent",
|
||||
"User-Agent": f"HermesAgent/{_HERMES_VERSION}",
|
||||
}
|
||||
|
||||
# Nous Portal extra_body for product attribution.
|
||||
# Callers should pass this as extra_body in chat.completions.create()
|
||||
# when the auxiliary client is backed by Nous Portal.
|
||||
@@ -396,7 +407,8 @@ class _CodexCompletionsAdapter:
|
||||
prompt_tokens=getattr(resp_usage, "input_tokens", 0),
|
||||
completion_tokens=getattr(resp_usage, "output_tokens", 0),
|
||||
total_tokens=getattr(resp_usage, "total_tokens", 0),
|
||||
)\n log_token_usage(usage.prompt_tokens, usage.completion_tokens, model)
|
||||
)
|
||||
log_token_usage(usage.prompt_tokens, usage.completion_tokens, model)
|
||||
except Exception as exc:
|
||||
logger.debug("Codex auxiliary Responses API call failed: %s", exc)
|
||||
raise
|
||||
@@ -529,7 +541,8 @@ class _AnthropicCompletionsAdapter:
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)\n log_token_usage(usage.prompt_tokens, usage.completion_tokens, model)
|
||||
)
|
||||
log_token_usage(usage.prompt_tokens, usage.completion_tokens, model)
|
||||
|
||||
choice = SimpleNamespace(
|
||||
index=0,
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
"""Transport layer types and registry for provider response normalization.
|
||||
|
||||
Usage:
|
||||
from agent.transports import get_transport
|
||||
transport = get_transport("anthropic_messages")
|
||||
result = transport.normalize_response(raw_response)
|
||||
"""
|
||||
|
||||
from agent.transports.types import ( # noqa: F401
|
||||
NormalizedResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
build_tool_call,
|
||||
map_finish_reason,
|
||||
)
|
||||
|
||||
_REGISTRY: dict = {}
|
||||
|
||||
|
||||
def register_transport(api_mode: str, transport_cls: type) -> None:
|
||||
"""Register a transport class for an api_mode string."""
|
||||
_REGISTRY[api_mode] = transport_cls
|
||||
|
||||
|
||||
def get_transport(api_mode: str):
|
||||
"""Get a transport instance for the given api_mode.
|
||||
|
||||
Returns None if no transport is registered for this api_mode.
|
||||
This allows gradual migration — call sites can check for None
|
||||
and fall back to the legacy code path.
|
||||
"""
|
||||
if not _REGISTRY:
|
||||
_discover_transports()
|
||||
cls = _REGISTRY.get(api_mode)
|
||||
if cls is None:
|
||||
return None
|
||||
return cls()
|
||||
|
||||
|
||||
def _discover_transports() -> None:
|
||||
"""Import all transport modules to trigger auto-registration."""
|
||||
try:
|
||||
import agent.transports.anthropic # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import agent.transports.codex # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import agent.transports.chat_completions # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import agent.transports.bedrock # noqa: F401
|
||||
except ImportError:
|
||||
pass
|
||||
@@ -1,95 +0,0 @@
|
||||
"""Anthropic Messages API transport.
|
||||
|
||||
Delegates to the existing adapter functions in agent/anthropic_adapter.py.
|
||||
This transport owns format conversion and normalization — NOT client lifecycle.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.transports.base import ProviderTransport
|
||||
from agent.transports.types import NormalizedResponse
|
||||
|
||||
|
||||
class AnthropicTransport(ProviderTransport):
|
||||
"""Transport for api_mode='anthropic_messages'."""
|
||||
|
||||
@property
|
||||
def api_mode(self) -> str:
|
||||
return "anthropic_messages"
|
||||
|
||||
def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
|
||||
from agent.anthropic_adapter import convert_messages_to_anthropic
|
||||
|
||||
base_url = kwargs.get("base_url")
|
||||
return convert_messages_to_anthropic(messages, base_url=base_url)
|
||||
|
||||
def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
|
||||
from agent.anthropic_adapter import convert_tools_to_anthropic
|
||||
|
||||
return convert_tools_to_anthropic(tools)
|
||||
|
||||
def build_kwargs(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**params,
|
||||
) -> Dict[str, Any]:
|
||||
from agent.anthropic_adapter import build_anthropic_kwargs
|
||||
|
||||
return build_anthropic_kwargs(
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
max_tokens=params.get("max_tokens", 16384),
|
||||
reasoning_config=params.get("reasoning_config"),
|
||||
tool_choice=params.get("tool_choice"),
|
||||
is_oauth=params.get("is_oauth", False),
|
||||
preserve_dots=params.get("preserve_dots", False),
|
||||
context_length=params.get("context_length"),
|
||||
base_url=params.get("base_url"),
|
||||
fast_mode=params.get("fast_mode", False),
|
||||
)
|
||||
|
||||
def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
|
||||
from agent.anthropic_adapter import normalize_anthropic_response_v2
|
||||
|
||||
strip_tool_prefix = kwargs.get("strip_tool_prefix", False)
|
||||
return normalize_anthropic_response_v2(response, strip_tool_prefix=strip_tool_prefix)
|
||||
|
||||
def validate_response(self, response: Any) -> bool:
|
||||
if response is None:
|
||||
return False
|
||||
content_blocks = getattr(response, "content", None)
|
||||
if not isinstance(content_blocks, list):
|
||||
return False
|
||||
if not content_blocks:
|
||||
return False
|
||||
return True
|
||||
|
||||
def extract_cache_stats(self, response: Any):
|
||||
usage = getattr(response, "usage", None)
|
||||
if usage is None:
|
||||
return None
|
||||
cached = getattr(usage, "cache_read_input_tokens", 0) or 0
|
||||
written = getattr(usage, "cache_creation_input_tokens", 0) or 0
|
||||
if cached or written:
|
||||
return {"cached_tokens": cached, "creation_tokens": written}
|
||||
return None
|
||||
|
||||
_STOP_REASON_MAP = {
|
||||
"end_turn": "stop",
|
||||
"tool_use": "tool_calls",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"refusal": "content_filter",
|
||||
"model_context_window_exceeded": "length",
|
||||
}
|
||||
|
||||
def map_finish_reason(self, raw_reason: str) -> str:
|
||||
return self._STOP_REASON_MAP.get(raw_reason, "stop")
|
||||
|
||||
|
||||
from agent.transports import register_transport # noqa: E402
|
||||
|
||||
register_transport("anthropic_messages", AnthropicTransport)
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Abstract base for provider transports.
|
||||
|
||||
A transport owns the data path for one api_mode:
|
||||
convert_messages → convert_tools → build_kwargs → normalize_response
|
||||
|
||||
It does NOT own: client construction, streaming, credential refresh,
|
||||
prompt caching, interrupt handling, or retry logic. Those stay on AIAgent.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from agent.transports.types import NormalizedResponse
|
||||
|
||||
|
||||
class ProviderTransport(ABC):
|
||||
"""Base class for provider-specific format conversion and normalization."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def api_mode(self) -> str:
|
||||
"""The api_mode string this transport handles."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def convert_messages(self, messages: List[Dict[str, Any]], **kwargs) -> Any:
|
||||
"""Convert OpenAI-format messages to provider-native format."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def convert_tools(self, tools: List[Dict[str, Any]]) -> Any:
|
||||
"""Convert OpenAI-format tool definitions to provider-native format."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def build_kwargs(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
**params,
|
||||
) -> Dict[str, Any]:
|
||||
"""Build the complete provider kwargs dict."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def normalize_response(self, response: Any, **kwargs) -> NormalizedResponse:
|
||||
"""Normalize a raw provider response to the shared NormalizedResponse type."""
|
||||
...
|
||||
|
||||
def validate_response(self, response: Any) -> bool:
|
||||
"""Optional structural validation for raw responses."""
|
||||
return True
|
||||
|
||||
def extract_cache_stats(self, response: Any) -> Optional[Dict[str, int]]:
|
||||
"""Optional cache stats extraction."""
|
||||
return None
|
||||
|
||||
def map_finish_reason(self, raw_reason: str) -> str:
|
||||
"""Optional stop-reason mapping. Defaults to passthrough."""
|
||||
return raw_reason
|
||||
@@ -1,58 +0,0 @@
|
||||
"""Shared types for normalized provider responses."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCall:
|
||||
"""A normalized tool call from any provider."""
|
||||
|
||||
id: Optional[str]
|
||||
name: str
|
||||
arguments: str
|
||||
provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Usage:
|
||||
"""Token usage from an API response."""
|
||||
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
cached_tokens: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class NormalizedResponse:
|
||||
"""Normalized API response from any provider."""
|
||||
|
||||
content: Optional[str]
|
||||
tool_calls: Optional[List[ToolCall]]
|
||||
finish_reason: str
|
||||
reasoning: Optional[str] = None
|
||||
usage: Optional[Usage] = None
|
||||
provider_data: Optional[Dict[str, Any]] = field(default=None, repr=False)
|
||||
|
||||
|
||||
def build_tool_call(
|
||||
id: Optional[str],
|
||||
name: str,
|
||||
arguments: Any,
|
||||
**provider_fields: Any,
|
||||
) -> ToolCall:
|
||||
"""Build a ToolCall, auto-serialising dict arguments."""
|
||||
args_str = json.dumps(arguments) if isinstance(arguments, dict) else str(arguments)
|
||||
provider_data = dict(provider_fields) if provider_fields else None
|
||||
return ToolCall(id=id, name=name, arguments=args_str, provider_data=provider_data)
|
||||
|
||||
|
||||
def map_finish_reason(reason: Optional[str], mapping: Dict[str, str]) -> str:
|
||||
"""Translate a provider-specific stop reason to the normalized set."""
|
||||
if reason is None:
|
||||
return "stop"
|
||||
return mapping.get(reason, "stop")
|
||||
@@ -168,7 +168,7 @@ import time as _time
|
||||
from datetime import datetime
|
||||
|
||||
from hermes_cli import __version__, __release_date__
|
||||
from hermes_constants import OPENROUTER_BASE_URL
|
||||
from hermes_constants import AI_GATEWAY_BASE_URL, OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -1112,6 +1112,8 @@ def select_provider_and_model(args=None):
|
||||
# Step 2: Provider-specific setup + model selection
|
||||
if selected_provider == "openrouter":
|
||||
_model_flow_openrouter(config, current_model)
|
||||
elif selected_provider == "ai-gateway":
|
||||
_model_flow_ai_gateway(config, current_model)
|
||||
elif selected_provider == "nous":
|
||||
_model_flow_nous(config, current_model, args=args)
|
||||
elif selected_provider == "openai-codex":
|
||||
@@ -1267,6 +1269,55 @@ def _model_flow_openrouter(config, current_model=""):
|
||||
print("No change.")
|
||||
|
||||
|
||||
def _model_flow_ai_gateway(config, current_model=""):
|
||||
"""Vercel AI Gateway provider: ensure API key, then pick model with pricing."""
|
||||
from hermes_cli.auth import _prompt_model_selection, _save_model_choice, deactivate_provider
|
||||
from hermes_cli.config import get_env_value, save_env_value
|
||||
from hermes_cli.models import ai_gateway_model_ids, get_pricing_for_provider
|
||||
|
||||
api_key = get_env_value("AI_GATEWAY_API_KEY")
|
||||
if not api_key:
|
||||
print("No Vercel AI Gateway API key configured.")
|
||||
print("Create API key here: https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai-gateway&title=AI+Gateway")
|
||||
print("Add a payment method to get $5 in free credits.")
|
||||
print()
|
||||
try:
|
||||
import getpass
|
||||
key = getpass.getpass("AI Gateway API key (or Enter to cancel): ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print()
|
||||
return
|
||||
if not key:
|
||||
print("Cancelled.")
|
||||
return
|
||||
save_env_value("AI_GATEWAY_API_KEY", key)
|
||||
print("API key saved.")
|
||||
print()
|
||||
|
||||
models_list = ai_gateway_model_ids(force_refresh=True)
|
||||
pricing = get_pricing_for_provider("ai-gateway", force_refresh=True)
|
||||
|
||||
selected = _prompt_model_selection(models_list, current_model=current_model, pricing=pricing)
|
||||
if selected:
|
||||
_save_model_choice(selected)
|
||||
|
||||
from hermes_cli.config import load_config, save_config
|
||||
|
||||
cfg = load_config()
|
||||
model = cfg.get("model")
|
||||
if not isinstance(model, dict):
|
||||
model = {"default": model} if model else {}
|
||||
cfg["model"] = model
|
||||
model["provider"] = "ai-gateway"
|
||||
model["base_url"] = AI_GATEWAY_BASE_URL
|
||||
model["api_mode"] = "chat_completions"
|
||||
save_config(cfg)
|
||||
deactivate_provider()
|
||||
print(f"Default model set to: {selected} (via Vercel AI Gateway)")
|
||||
else:
|
||||
print("No change.")
|
||||
|
||||
|
||||
def _model_flow_nous(config, current_model="", args=None):
|
||||
"""Nous Portal provider: ensure logged in, then pick model."""
|
||||
from hermes_cli.auth import (
|
||||
|
||||
@@ -58,6 +58,28 @@ OPENROUTER_MODELS: list[tuple[str, str]] = [
|
||||
|
||||
_openrouter_catalog_cache: list[tuple[str, str]] | None = None
|
||||
|
||||
# Fallback Vercel AI Gateway snapshot used when the live catalog is unavailable.
|
||||
# OSS / open-weight models prioritized first, then closed-source by family.
|
||||
VERCEL_AI_GATEWAY_MODELS: list[tuple[str, str]] = [
|
||||
("moonshotai/kimi-k2.6", "recommended"),
|
||||
("alibaba/qwen3.6-plus", ""),
|
||||
("zai/glm-5.1", ""),
|
||||
("minimax/minimax-m2.7", ""),
|
||||
("anthropic/claude-sonnet-4.6", ""),
|
||||
("anthropic/claude-opus-4.7", ""),
|
||||
("anthropic/claude-opus-4.6", ""),
|
||||
("anthropic/claude-haiku-4.5", ""),
|
||||
("openai/gpt-5.4", ""),
|
||||
("openai/gpt-5.4-mini", ""),
|
||||
("openai/gpt-5.3-codex", ""),
|
||||
("google/gemini-3.1-pro-preview", ""),
|
||||
("google/gemini-3-flash", ""),
|
||||
("google/gemini-3.1-flash-lite-preview", ""),
|
||||
("xai/grok-4.20-reasoning", ""),
|
||||
]
|
||||
|
||||
_ai_gateway_catalog_cache: list[tuple[str, str]] | None = None
|
||||
|
||||
|
||||
def _codex_curated_models() -> list[str]:
|
||||
"""Derive the openai-codex curated list from codex_models.py.
|
||||
@@ -258,18 +280,21 @@ _PROVIDER_MODELS: dict[str, list[str]] = {
|
||||
"minimax-m2.5",
|
||||
],
|
||||
"ai-gateway": [
|
||||
"anthropic/claude-opus-4.6",
|
||||
"moonshotai/kimi-k2.6",
|
||||
"alibaba/qwen3.6-plus",
|
||||
"zai/glm-5.1",
|
||||
"minimax/minimax-m2.7",
|
||||
"anthropic/claude-sonnet-4.6",
|
||||
"anthropic/claude-sonnet-4.5",
|
||||
"anthropic/claude-opus-4.7",
|
||||
"anthropic/claude-opus-4.6",
|
||||
"anthropic/claude-haiku-4.5",
|
||||
"openai/gpt-5",
|
||||
"openai/gpt-4.1",
|
||||
"openai/gpt-4.1-mini",
|
||||
"google/gemini-3-pro-preview",
|
||||
"openai/gpt-5.4",
|
||||
"openai/gpt-5.4-mini",
|
||||
"openai/gpt-5.3-codex",
|
||||
"google/gemini-3.1-pro-preview",
|
||||
"google/gemini-3-flash",
|
||||
"google/gemini-2.5-pro",
|
||||
"google/gemini-2.5-flash",
|
||||
"deepseek/deepseek-v3.2",
|
||||
"google/gemini-3.1-flash-lite-preview",
|
||||
"xai/grok-4.20-reasoning",
|
||||
],
|
||||
"kilocode": [
|
||||
"anthropic/claude-opus-4.6",
|
||||
@@ -516,6 +541,7 @@ class ProviderEntry(NamedTuple):
|
||||
CANONICAL_PROVIDERS: list[ProviderEntry] = [
|
||||
ProviderEntry("nous", "Nous Portal", "Nous Portal (Nous Research subscription)"),
|
||||
ProviderEntry("openrouter", "OpenRouter", "OpenRouter (100+ models, pay-per-use)"),
|
||||
ProviderEntry("ai-gateway", "Vercel AI Gateway", "Vercel AI Gateway (200+ models, $5 free credit, no markup)"),
|
||||
ProviderEntry("anthropic", "Anthropic", "Anthropic (Claude models — API key or Claude Code)"),
|
||||
ProviderEntry("openai-codex", "OpenAI Codex", "OpenAI Codex"),
|
||||
ProviderEntry("xiaomi", "Xiaomi MiMo", "Xiaomi MiMo (MiMo-V2 models — pro, omni, flash)"),
|
||||
@@ -536,7 +562,6 @@ CANONICAL_PROVIDERS: list[ProviderEntry] = [
|
||||
ProviderEntry("kilocode", "Kilo Code", "Kilo Code (Kilo Gateway API)"),
|
||||
ProviderEntry("opencode-zen", "OpenCode Zen", "OpenCode Zen (35+ curated models, pay-as-you-go)"),
|
||||
ProviderEntry("opencode-go", "OpenCode Go", "OpenCode Go (open models, $10/month subscription)"),
|
||||
ProviderEntry("ai-gateway", "Vercel AI Gateway", "Vercel AI Gateway (200+ models, pay-per-use)"),
|
||||
]
|
||||
|
||||
# Derived dicts — used throughout the codebase
|
||||
@@ -679,6 +704,90 @@ def model_ids(*, force_refresh: bool = False) -> list[str]:
|
||||
|
||||
|
||||
|
||||
def _ai_gateway_model_is_free(pricing: Any) -> bool:
|
||||
"""Return True if an AI Gateway model has $0 input AND output pricing."""
|
||||
if not isinstance(pricing, dict):
|
||||
return False
|
||||
try:
|
||||
return float(pricing.get("input", "0")) == 0 and float(pricing.get("output", "0")) == 0
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
def fetch_ai_gateway_models(
|
||||
timeout: float = 8.0,
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
) -> list[tuple[str, str]]:
|
||||
"""Return the curated AI Gateway picker list, refreshed from the live catalog when possible."""
|
||||
global _ai_gateway_catalog_cache
|
||||
|
||||
if _ai_gateway_catalog_cache is not None and not force_refresh:
|
||||
return list(_ai_gateway_catalog_cache)
|
||||
|
||||
from hermes_constants import AI_GATEWAY_BASE_URL
|
||||
|
||||
fallback = list(VERCEL_AI_GATEWAY_MODELS)
|
||||
preferred_ids = [mid for mid, _ in fallback]
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"{AI_GATEWAY_BASE_URL.rstrip('/')}/models",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
payload = json.loads(resp.read().decode())
|
||||
except Exception:
|
||||
return list(_ai_gateway_catalog_cache or fallback)
|
||||
|
||||
live_items = payload.get("data", [])
|
||||
if not isinstance(live_items, list):
|
||||
return list(_ai_gateway_catalog_cache or fallback)
|
||||
|
||||
live_by_id: dict[str, dict[str, Any]] = {}
|
||||
for item in live_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
mid = str(item.get("id") or "").strip()
|
||||
if not mid:
|
||||
continue
|
||||
live_by_id[mid] = item
|
||||
|
||||
curated: list[tuple[str, str]] = []
|
||||
for preferred_id in preferred_ids:
|
||||
live_item = live_by_id.get(preferred_id)
|
||||
if live_item is None:
|
||||
continue
|
||||
desc = "free" if _ai_gateway_model_is_free(live_item.get("pricing")) else ""
|
||||
curated.append((preferred_id, desc))
|
||||
|
||||
if not curated:
|
||||
return list(_ai_gateway_catalog_cache or fallback)
|
||||
|
||||
free_moonshot = next(
|
||||
(
|
||||
mid
|
||||
for mid, item in live_by_id.items()
|
||||
if mid.startswith("moonshotai/") and _ai_gateway_model_is_free(item.get("pricing"))
|
||||
),
|
||||
None,
|
||||
)
|
||||
if free_moonshot:
|
||||
curated = [(mid, desc) for mid, desc in curated if mid != free_moonshot]
|
||||
curated.insert(0, (free_moonshot, "recommended"))
|
||||
else:
|
||||
first_id, _ = curated[0]
|
||||
curated[0] = (first_id, "recommended")
|
||||
|
||||
_ai_gateway_catalog_cache = curated
|
||||
return list(curated)
|
||||
|
||||
|
||||
def ai_gateway_model_ids(*, force_refresh: bool = False) -> list[str]:
|
||||
"""Return just the AI Gateway model-id strings."""
|
||||
return [mid for mid, _ in fetch_ai_gateway_models(force_refresh=force_refresh)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pricing helpers — fetch live pricing from OpenRouter-compatible /v1/models
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -821,6 +930,51 @@ def fetch_models_with_pricing(
|
||||
return result
|
||||
|
||||
|
||||
def fetch_ai_gateway_pricing(
|
||||
timeout: float = 8.0,
|
||||
*,
|
||||
force_refresh: bool = False,
|
||||
) -> dict[str, dict[str, str]]:
|
||||
"""Fetch Vercel AI Gateway /v1/models and return Hermes-shaped pricing."""
|
||||
from hermes_constants import AI_GATEWAY_BASE_URL
|
||||
|
||||
cache_key = AI_GATEWAY_BASE_URL.rstrip("/")
|
||||
if not force_refresh and cache_key in _pricing_cache:
|
||||
return _pricing_cache[cache_key]
|
||||
|
||||
try:
|
||||
req = urllib.request.Request(
|
||||
f"{cache_key}/models",
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
with urllib.request.urlopen(req, timeout=timeout) as resp:
|
||||
payload = json.loads(resp.read().decode())
|
||||
except Exception:
|
||||
_pricing_cache[cache_key] = {}
|
||||
return {}
|
||||
|
||||
result: dict[str, dict[str, str]] = {}
|
||||
for item in payload.get("data", []):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
mid = item.get("id")
|
||||
pricing = item.get("pricing")
|
||||
if not (mid and isinstance(pricing, dict)):
|
||||
continue
|
||||
entry: dict[str, str] = {
|
||||
"prompt": str(pricing.get("input", "")),
|
||||
"completion": str(pricing.get("output", "")),
|
||||
}
|
||||
if pricing.get("input_cache_read"):
|
||||
entry["input_cache_read"] = str(pricing["input_cache_read"])
|
||||
if pricing.get("input_cache_write"):
|
||||
entry["input_cache_write"] = str(pricing["input_cache_write"])
|
||||
result[mid] = entry
|
||||
|
||||
_pricing_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_openrouter_api_key() -> str:
|
||||
"""Best-effort OpenRouter API key for pricing fetch."""
|
||||
return os.getenv("OPENROUTER_API_KEY", "").strip()
|
||||
@@ -839,7 +993,7 @@ def _resolve_nous_pricing_credentials() -> tuple[str, str]:
|
||||
|
||||
|
||||
def get_pricing_for_provider(provider: str, *, force_refresh: bool = False) -> dict[str, dict[str, str]]:
|
||||
"""Return live pricing for providers that support it (openrouter, nous)."""
|
||||
"""Return live pricing for providers that support it (openrouter, ai-gateway, nous)."""
|
||||
normalized = normalize_provider(provider)
|
||||
if normalized == "openrouter":
|
||||
return fetch_models_with_pricing(
|
||||
@@ -847,11 +1001,11 @@ def get_pricing_for_provider(provider: str, *, force_refresh: bool = False) -> d
|
||||
base_url="https://openrouter.ai/api",
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
if normalized == "ai-gateway":
|
||||
return fetch_ai_gateway_pricing(force_refresh=force_refresh)
|
||||
if normalized == "nous":
|
||||
api_key, base_url = _resolve_nous_pricing_credentials()
|
||||
if base_url:
|
||||
# Nous base_url typically looks like https://inference-api.nousresearch.com/v1
|
||||
# We need the part before /v1 for our fetch function
|
||||
stripped = base_url.rstrip("/")
|
||||
if stripped.endswith("/v1"):
|
||||
stripped = stripped[:-3]
|
||||
@@ -1253,9 +1407,7 @@ def provider_model_ids(provider: Optional[str], *, force_refresh: bool = False)
|
||||
if live:
|
||||
return live
|
||||
if normalized == "ai-gateway":
|
||||
live = _fetch_ai_gateway_models()
|
||||
if live:
|
||||
return live
|
||||
return ai_gateway_model_ids()
|
||||
if normalized == "custom":
|
||||
base_url = _get_custom_base_url()
|
||||
if base_url:
|
||||
|
||||
@@ -908,6 +908,10 @@ class AIAgent:
|
||||
"X-OpenRouter-Title": "Hermes Agent",
|
||||
"X-OpenRouter-Categories": "productivity,cli-agent",
|
||||
}
|
||||
elif "ai-gateway.vercel.sh" in effective_base.lower():
|
||||
from agent.auxiliary_client import _AI_GATEWAY_HEADERS
|
||||
|
||||
client_kwargs["default_headers"] = dict(_AI_GATEWAY_HEADERS)
|
||||
elif "api.githubcopilot.com" in effective_base.lower():
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
@@ -4667,11 +4671,13 @@ class AIAgent:
|
||||
return True
|
||||
|
||||
def _apply_client_headers_for_base_url(self, base_url: str) -> None:
|
||||
from agent.auxiliary_client import _OR_HEADERS
|
||||
from agent.auxiliary_client import _AI_GATEWAY_HEADERS, _OR_HEADERS
|
||||
|
||||
normalized = (base_url or "").lower()
|
||||
if "openrouter" in normalized:
|
||||
self._client_kwargs["default_headers"] = dict(_OR_HEADERS)
|
||||
elif "ai-gateway.vercel.sh" in normalized:
|
||||
self._client_kwargs["default_headers"] = dict(_AI_GATEWAY_HEADERS)
|
||||
elif "api.githubcopilot.com" in normalized:
|
||||
from hermes_cli.models import copilot_default_headers
|
||||
|
||||
|
||||
@@ -1,213 +0,0 @@
|
||||
"""Regression tests: normalize_anthropic_response_v2 vs v1.
|
||||
|
||||
Constructs mock Anthropic responses and asserts that the v2 function
|
||||
(returning NormalizedResponse) produces identical field values to the
|
||||
original v1 function (returning SimpleNamespace + finish_reason).
|
||||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.anthropic_adapter import (
|
||||
normalize_anthropic_response,
|
||||
normalize_anthropic_response_v2,
|
||||
)
|
||||
from agent.transports.types import NormalizedResponse
|
||||
|
||||
|
||||
def _text_block(text: str):
|
||||
return SimpleNamespace(type="text", text=text)
|
||||
|
||||
|
||||
def _thinking_block(thinking: str, signature: str = "sig_abc"):
|
||||
return SimpleNamespace(type="thinking", thinking=thinking, signature=signature)
|
||||
|
||||
|
||||
def _tool_use_block(id: str, name: str, input: dict):
|
||||
return SimpleNamespace(type="tool_use", id=id, name=name, input=input)
|
||||
|
||||
|
||||
def _response(content_blocks, stop_reason="end_turn"):
|
||||
return SimpleNamespace(
|
||||
content=content_blocks,
|
||||
stop_reason=stop_reason,
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=5),
|
||||
)
|
||||
|
||||
|
||||
class TestTextOnly:
|
||||
def setup_method(self):
|
||||
self.resp = _response([_text_block("Hello world")])
|
||||
self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
|
||||
self.v2 = normalize_anthropic_response_v2(self.resp)
|
||||
|
||||
def test_type(self):
|
||||
assert isinstance(self.v2, NormalizedResponse)
|
||||
|
||||
def test_content_matches(self):
|
||||
assert self.v2.content == self.v1_msg.content
|
||||
|
||||
def test_finish_reason_matches(self):
|
||||
assert self.v2.finish_reason == self.v1_finish
|
||||
|
||||
def test_no_tool_calls(self):
|
||||
assert self.v2.tool_calls is None
|
||||
assert self.v1_msg.tool_calls is None
|
||||
|
||||
def test_no_reasoning(self):
|
||||
assert self.v2.reasoning is None
|
||||
assert self.v1_msg.reasoning is None
|
||||
|
||||
|
||||
class TestWithToolCalls:
|
||||
def setup_method(self):
|
||||
self.resp = _response(
|
||||
[
|
||||
_text_block("I'll check that"),
|
||||
_tool_use_block("toolu_abc", "terminal", {"command": "ls"}),
|
||||
_tool_use_block("toolu_def", "read_file", {"path": "/tmp"}),
|
||||
],
|
||||
stop_reason="tool_use",
|
||||
)
|
||||
self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
|
||||
self.v2 = normalize_anthropic_response_v2(self.resp)
|
||||
|
||||
def test_finish_reason(self):
|
||||
assert self.v2.finish_reason == "tool_calls"
|
||||
assert self.v1_finish == "tool_calls"
|
||||
|
||||
def test_tool_call_count(self):
|
||||
assert len(self.v2.tool_calls) == 2
|
||||
assert len(self.v1_msg.tool_calls) == 2
|
||||
|
||||
def test_tool_call_ids_match(self):
|
||||
for i in range(2):
|
||||
assert self.v2.tool_calls[i].id == self.v1_msg.tool_calls[i].id
|
||||
|
||||
def test_tool_call_names_match(self):
|
||||
assert self.v2.tool_calls[0].name == "terminal"
|
||||
assert self.v2.tool_calls[1].name == "read_file"
|
||||
for i in range(2):
|
||||
assert self.v2.tool_calls[i].name == self.v1_msg.tool_calls[i].function.name
|
||||
|
||||
def test_tool_call_arguments_match(self):
|
||||
for i in range(2):
|
||||
assert self.v2.tool_calls[i].arguments == self.v1_msg.tool_calls[i].function.arguments
|
||||
|
||||
def test_content_preserved(self):
|
||||
assert self.v2.content == self.v1_msg.content
|
||||
assert "check that" in self.v2.content
|
||||
|
||||
|
||||
class TestWithThinking:
|
||||
def setup_method(self):
|
||||
self.resp = _response([
|
||||
_thinking_block("Let me think about this carefully..."),
|
||||
_text_block("The answer is 42."),
|
||||
])
|
||||
self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
|
||||
self.v2 = normalize_anthropic_response_v2(self.resp)
|
||||
|
||||
def test_reasoning_matches(self):
|
||||
assert self.v2.reasoning == self.v1_msg.reasoning
|
||||
assert "think about this" in self.v2.reasoning
|
||||
|
||||
def test_reasoning_details_in_provider_data(self):
|
||||
v1_details = self.v1_msg.reasoning_details
|
||||
v2_details = self.v2.provider_data.get("reasoning_details") if self.v2.provider_data else None
|
||||
assert v1_details is not None
|
||||
assert v2_details is not None
|
||||
assert len(v2_details) == len(v1_details)
|
||||
|
||||
def test_content_excludes_thinking(self):
|
||||
assert self.v2.content == "The answer is 42."
|
||||
|
||||
|
||||
class TestMixed:
|
||||
def setup_method(self):
|
||||
self.resp = _response(
|
||||
[
|
||||
_thinking_block("Planning my approach..."),
|
||||
_text_block("I'll run the command"),
|
||||
_tool_use_block("toolu_xyz", "terminal", {"command": "pwd"}),
|
||||
],
|
||||
stop_reason="tool_use",
|
||||
)
|
||||
self.v1_msg, self.v1_finish = normalize_anthropic_response(self.resp)
|
||||
self.v2 = normalize_anthropic_response_v2(self.resp)
|
||||
|
||||
def test_all_fields_present(self):
|
||||
assert self.v2.content is not None
|
||||
assert self.v2.tool_calls is not None
|
||||
assert self.v2.reasoning is not None
|
||||
assert self.v2.finish_reason == "tool_calls"
|
||||
|
||||
def test_content_matches(self):
|
||||
assert self.v2.content == self.v1_msg.content
|
||||
|
||||
def test_reasoning_matches(self):
|
||||
assert self.v2.reasoning == self.v1_msg.reasoning
|
||||
|
||||
def test_tool_call_matches(self):
|
||||
assert self.v2.tool_calls[0].id == self.v1_msg.tool_calls[0].id
|
||||
assert self.v2.tool_calls[0].name == self.v1_msg.tool_calls[0].function.name
|
||||
|
||||
|
||||
class TestStopReasons:
|
||||
@pytest.mark.parametrize("stop_reason,expected", [
|
||||
("end_turn", "stop"),
|
||||
("tool_use", "tool_calls"),
|
||||
("max_tokens", "length"),
|
||||
("stop_sequence", "stop"),
|
||||
("refusal", "content_filter"),
|
||||
("model_context_window_exceeded", "length"),
|
||||
("unknown_future_reason", "stop"),
|
||||
])
|
||||
def test_stop_reason_mapping(self, stop_reason, expected):
|
||||
resp = _response([_text_block("x")], stop_reason=stop_reason)
|
||||
_v1_msg, v1_finish = normalize_anthropic_response(resp)
|
||||
v2 = normalize_anthropic_response_v2(resp)
|
||||
assert v2.finish_reason == v1_finish == expected
|
||||
|
||||
|
||||
class TestStripToolPrefix:
|
||||
def test_prefix_stripped(self):
|
||||
resp = _response(
|
||||
[_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})],
|
||||
stop_reason="tool_use",
|
||||
)
|
||||
v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=True)
|
||||
v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=True)
|
||||
assert v1_msg.tool_calls[0].function.name == "terminal"
|
||||
assert v2.tool_calls[0].name == "terminal"
|
||||
|
||||
def test_prefix_kept(self):
|
||||
resp = _response(
|
||||
[_tool_use_block("toolu_1", "mcp_terminal", {"cmd": "ls"})],
|
||||
stop_reason="tool_use",
|
||||
)
|
||||
v1_msg, _ = normalize_anthropic_response(resp, strip_tool_prefix=False)
|
||||
v2 = normalize_anthropic_response_v2(resp, strip_tool_prefix=False)
|
||||
assert v1_msg.tool_calls[0].function.name == "mcp_terminal"
|
||||
assert v2.tool_calls[0].name == "mcp_terminal"
|
||||
|
||||
|
||||
class TestEdgeCases:
|
||||
def test_empty_content_blocks(self):
|
||||
resp = _response([])
|
||||
v1_msg, _v1_finish = normalize_anthropic_response(resp)
|
||||
v2 = normalize_anthropic_response_v2(resp)
|
||||
assert v2.content == v1_msg.content
|
||||
assert v2.content is None
|
||||
|
||||
def test_no_reasoning_details_means_none_provider_data(self):
|
||||
resp = _response([_text_block("hi")])
|
||||
v2 = normalize_anthropic_response_v2(resp)
|
||||
assert v2.provider_data is None
|
||||
|
||||
def test_v2_returns_dataclass_not_namespace(self):
|
||||
resp = _response([_text_block("hi")])
|
||||
v2 = normalize_anthropic_response_v2(resp)
|
||||
assert isinstance(v2, NormalizedResponse)
|
||||
assert not isinstance(v2, SimpleNamespace)
|
||||
@@ -1,208 +0,0 @@
|
||||
"""Tests for the transport ABC, registry, and AnthropicTransport."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.transports import _REGISTRY, get_transport, register_transport
|
||||
from agent.transports.base import ProviderTransport
|
||||
from agent.transports.types import NormalizedResponse
|
||||
|
||||
|
||||
class TestProviderTransportABC:
|
||||
def test_cannot_instantiate_abc(self):
|
||||
with pytest.raises(TypeError):
|
||||
ProviderTransport()
|
||||
|
||||
def test_concrete_must_implement_all_abstract(self):
|
||||
class Incomplete(ProviderTransport):
|
||||
@property
|
||||
def api_mode(self):
|
||||
return "test"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
Incomplete()
|
||||
|
||||
def test_minimal_concrete(self):
|
||||
class Minimal(ProviderTransport):
|
||||
@property
|
||||
def api_mode(self):
|
||||
return "test_minimal"
|
||||
|
||||
def convert_messages(self, messages, **kw):
|
||||
return messages
|
||||
|
||||
def convert_tools(self, tools):
|
||||
return tools
|
||||
|
||||
def build_kwargs(self, model, messages, tools=None, **params):
|
||||
return {"model": model, "messages": messages}
|
||||
|
||||
def normalize_response(self, response, **kw):
|
||||
return NormalizedResponse(content="ok", tool_calls=None, finish_reason="stop")
|
||||
|
||||
t = Minimal()
|
||||
assert t.api_mode == "test_minimal"
|
||||
assert t.validate_response(None) is True
|
||||
assert t.extract_cache_stats(None) is None
|
||||
assert t.map_finish_reason("end_turn") == "end_turn"
|
||||
|
||||
|
||||
class TestTransportRegistry:
|
||||
def test_get_unregistered_returns_none(self):
|
||||
assert get_transport("nonexistent_mode") is None
|
||||
|
||||
def test_anthropic_registered_on_import(self):
|
||||
import agent.transports.anthropic # noqa: F401
|
||||
|
||||
t = get_transport("anthropic_messages")
|
||||
assert t is not None
|
||||
assert t.api_mode == "anthropic_messages"
|
||||
|
||||
def test_register_and_get(self):
|
||||
class DummyTransport(ProviderTransport):
|
||||
@property
|
||||
def api_mode(self):
|
||||
return "dummy_test"
|
||||
|
||||
def convert_messages(self, messages, **kw):
|
||||
return messages
|
||||
|
||||
def convert_tools(self, tools):
|
||||
return tools
|
||||
|
||||
def build_kwargs(self, model, messages, tools=None, **params):
|
||||
return {}
|
||||
|
||||
def normalize_response(self, response, **kw):
|
||||
return NormalizedResponse(content=None, tool_calls=None, finish_reason="stop")
|
||||
|
||||
register_transport("dummy_test", DummyTransport)
|
||||
t = get_transport("dummy_test")
|
||||
assert t.api_mode == "dummy_test"
|
||||
_REGISTRY.pop("dummy_test", None)
|
||||
|
||||
|
||||
class TestAnthropicTransport:
|
||||
@pytest.fixture
|
||||
def transport(self):
|
||||
import agent.transports.anthropic # noqa: F401
|
||||
|
||||
return get_transport("anthropic_messages")
|
||||
|
||||
def test_api_mode(self, transport):
|
||||
assert transport.api_mode == "anthropic_messages"
|
||||
|
||||
def test_convert_tools_simple(self, transport):
|
||||
tools = [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "test_tool",
|
||||
"description": "A test",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}]
|
||||
result = transport.convert_tools(tools)
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "test_tool"
|
||||
assert "input_schema" in result[0]
|
||||
|
||||
def test_validate_response_none(self, transport):
|
||||
assert transport.validate_response(None) is False
|
||||
|
||||
def test_validate_response_empty_content(self, transport):
|
||||
r = SimpleNamespace(content=[])
|
||||
assert transport.validate_response(r) is False
|
||||
|
||||
def test_validate_response_valid(self, transport):
|
||||
r = SimpleNamespace(content=[SimpleNamespace(type="text", text="hello")])
|
||||
assert transport.validate_response(r) is True
|
||||
|
||||
def test_map_finish_reason(self, transport):
|
||||
assert transport.map_finish_reason("end_turn") == "stop"
|
||||
assert transport.map_finish_reason("tool_use") == "tool_calls"
|
||||
assert transport.map_finish_reason("max_tokens") == "length"
|
||||
assert transport.map_finish_reason("stop_sequence") == "stop"
|
||||
assert transport.map_finish_reason("refusal") == "content_filter"
|
||||
assert transport.map_finish_reason("model_context_window_exceeded") == "length"
|
||||
assert transport.map_finish_reason("unknown") == "stop"
|
||||
|
||||
def test_extract_cache_stats_none_usage(self, transport):
|
||||
r = SimpleNamespace(usage=None)
|
||||
assert transport.extract_cache_stats(r) is None
|
||||
|
||||
def test_extract_cache_stats_with_cache(self, transport):
|
||||
usage = SimpleNamespace(cache_read_input_tokens=100, cache_creation_input_tokens=50)
|
||||
r = SimpleNamespace(usage=usage)
|
||||
result = transport.extract_cache_stats(r)
|
||||
assert result == {"cached_tokens": 100, "creation_tokens": 50}
|
||||
|
||||
def test_extract_cache_stats_zero(self, transport):
|
||||
usage = SimpleNamespace(cache_read_input_tokens=0, cache_creation_input_tokens=0)
|
||||
r = SimpleNamespace(usage=usage)
|
||||
assert transport.extract_cache_stats(r) is None
|
||||
|
||||
def test_normalize_response_text(self, transport):
|
||||
r = SimpleNamespace(
|
||||
content=[SimpleNamespace(type="text", text="Hello world")],
|
||||
stop_reason="end_turn",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=5),
|
||||
model="claude-sonnet-4-6",
|
||||
)
|
||||
nr = transport.normalize_response(r)
|
||||
assert isinstance(nr, NormalizedResponse)
|
||||
assert nr.content == "Hello world"
|
||||
assert nr.tool_calls is None or nr.tool_calls == []
|
||||
assert nr.finish_reason == "stop"
|
||||
|
||||
def test_normalize_response_tool_calls(self, transport):
|
||||
r = SimpleNamespace(
|
||||
content=[
|
||||
SimpleNamespace(type="tool_use", id="toolu_123", name="terminal", input={"command": "ls"}),
|
||||
],
|
||||
stop_reason="tool_use",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=20),
|
||||
model="claude-sonnet-4-6",
|
||||
)
|
||||
nr = transport.normalize_response(r)
|
||||
assert nr.finish_reason == "tool_calls"
|
||||
assert len(nr.tool_calls) == 1
|
||||
tc = nr.tool_calls[0]
|
||||
assert tc.name == "terminal"
|
||||
assert tc.id == "toolu_123"
|
||||
assert '"command"' in tc.arguments
|
||||
|
||||
def test_normalize_response_thinking(self, transport):
|
||||
r = SimpleNamespace(
|
||||
content=[
|
||||
SimpleNamespace(type="thinking", thinking="Let me think..."),
|
||||
SimpleNamespace(type="text", text="The answer is 42"),
|
||||
],
|
||||
stop_reason="end_turn",
|
||||
usage=SimpleNamespace(input_tokens=10, output_tokens=15),
|
||||
model="claude-sonnet-4-6",
|
||||
)
|
||||
nr = transport.normalize_response(r)
|
||||
assert nr.content == "The answer is 42"
|
||||
assert nr.reasoning == "Let me think..."
|
||||
|
||||
def test_build_kwargs_returns_dict(self, transport):
|
||||
messages = [{"role": "user", "content": "Hello"}]
|
||||
kw = transport.build_kwargs(
|
||||
model="claude-sonnet-4-6",
|
||||
messages=messages,
|
||||
max_tokens=1024,
|
||||
)
|
||||
assert isinstance(kw, dict)
|
||||
assert "model" in kw
|
||||
assert "max_tokens" in kw
|
||||
assert "messages" in kw
|
||||
|
||||
def test_convert_messages_extracts_system(self, transport):
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hi"},
|
||||
]
|
||||
system, msgs = transport.convert_messages(messages)
|
||||
assert system is not None
|
||||
assert len(msgs) >= 1
|
||||
@@ -1,130 +0,0 @@
|
||||
"""Tests for agent/transports/types.py — dataclass construction + helpers."""
|
||||
|
||||
import json
|
||||
|
||||
from agent.transports.types import (
|
||||
NormalizedResponse,
|
||||
ToolCall,
|
||||
Usage,
|
||||
build_tool_call,
|
||||
map_finish_reason,
|
||||
)
|
||||
|
||||
|
||||
class TestToolCall:
|
||||
def test_basic_construction(self):
|
||||
tc = ToolCall(id="call_abc", name="terminal", arguments='{"cmd": "ls"}')
|
||||
assert tc.id == "call_abc"
|
||||
assert tc.name == "terminal"
|
||||
assert tc.arguments == '{"cmd": "ls"}'
|
||||
assert tc.provider_data is None
|
||||
|
||||
def test_none_id(self):
|
||||
tc = ToolCall(id=None, name="read_file", arguments="{}")
|
||||
assert tc.id is None
|
||||
|
||||
def test_provider_data(self):
|
||||
tc = ToolCall(
|
||||
id="call_x",
|
||||
name="t",
|
||||
arguments="{}",
|
||||
provider_data={"call_id": "call_x", "response_item_id": "fc_x"},
|
||||
)
|
||||
assert tc.provider_data["call_id"] == "call_x"
|
||||
assert tc.provider_data["response_item_id"] == "fc_x"
|
||||
|
||||
|
||||
class TestUsage:
|
||||
def test_defaults(self):
|
||||
u = Usage()
|
||||
assert u.prompt_tokens == 0
|
||||
assert u.completion_tokens == 0
|
||||
assert u.total_tokens == 0
|
||||
assert u.cached_tokens == 0
|
||||
|
||||
def test_explicit(self):
|
||||
u = Usage(prompt_tokens=100, completion_tokens=50, total_tokens=150, cached_tokens=80)
|
||||
assert u.total_tokens == 150
|
||||
|
||||
|
||||
class TestNormalizedResponse:
|
||||
def test_text_only(self):
|
||||
r = NormalizedResponse(content="hello", tool_calls=None, finish_reason="stop")
|
||||
assert r.content == "hello"
|
||||
assert r.tool_calls is None
|
||||
assert r.finish_reason == "stop"
|
||||
assert r.reasoning is None
|
||||
assert r.usage is None
|
||||
assert r.provider_data is None
|
||||
|
||||
def test_with_tool_calls(self):
|
||||
tcs = [ToolCall(id="call_1", name="terminal", arguments='{"cmd":"pwd"}')]
|
||||
r = NormalizedResponse(content=None, tool_calls=tcs, finish_reason="tool_calls")
|
||||
assert r.finish_reason == "tool_calls"
|
||||
assert len(r.tool_calls) == 1
|
||||
assert r.tool_calls[0].name == "terminal"
|
||||
|
||||
def test_with_reasoning(self):
|
||||
r = NormalizedResponse(
|
||||
content="answer",
|
||||
tool_calls=None,
|
||||
finish_reason="stop",
|
||||
reasoning="I thought about it",
|
||||
)
|
||||
assert r.reasoning == "I thought about it"
|
||||
|
||||
def test_with_provider_data(self):
|
||||
r = NormalizedResponse(
|
||||
content=None,
|
||||
tool_calls=None,
|
||||
finish_reason="stop",
|
||||
provider_data={"reasoning_details": [{"type": "thinking", "thinking": "hmm"}]},
|
||||
)
|
||||
assert r.provider_data["reasoning_details"][0]["type"] == "thinking"
|
||||
|
||||
|
||||
class TestBuildToolCall:
|
||||
def test_dict_arguments_serialized(self):
|
||||
tc = build_tool_call(id="call_1", name="terminal", arguments={"cmd": "ls"})
|
||||
assert tc.arguments == json.dumps({"cmd": "ls"})
|
||||
assert tc.provider_data is None
|
||||
|
||||
def test_string_arguments_passthrough(self):
|
||||
tc = build_tool_call(id="call_2", name="read_file", arguments='{"path": "/tmp"}')
|
||||
assert tc.arguments == '{"path": "/tmp"}'
|
||||
|
||||
def test_provider_fields(self):
|
||||
tc = build_tool_call(
|
||||
id="call_3",
|
||||
name="terminal",
|
||||
arguments="{}",
|
||||
call_id="call_3",
|
||||
response_item_id="fc_3",
|
||||
)
|
||||
assert tc.provider_data == {"call_id": "call_3", "response_item_id": "fc_3"}
|
||||
|
||||
def test_none_id(self):
|
||||
tc = build_tool_call(id=None, name="t", arguments="{}")
|
||||
assert tc.id is None
|
||||
|
||||
|
||||
class TestMapFinishReason:
|
||||
ANTHROPIC_MAP = {
|
||||
"end_turn": "stop",
|
||||
"tool_use": "tool_calls",
|
||||
"max_tokens": "length",
|
||||
"stop_sequence": "stop",
|
||||
"refusal": "content_filter",
|
||||
}
|
||||
|
||||
def test_known_reason(self):
|
||||
assert map_finish_reason("end_turn", self.ANTHROPIC_MAP) == "stop"
|
||||
assert map_finish_reason("tool_use", self.ANTHROPIC_MAP) == "tool_calls"
|
||||
assert map_finish_reason("max_tokens", self.ANTHROPIC_MAP) == "length"
|
||||
assert map_finish_reason("refusal", self.ANTHROPIC_MAP) == "content_filter"
|
||||
|
||||
def test_unknown_reason_defaults_to_stop(self):
|
||||
assert map_finish_reason("something_new", self.ANTHROPIC_MAP) == "stop"
|
||||
|
||||
def test_none_reason(self):
|
||||
assert map_finish_reason(None, self.ANTHROPIC_MAP) == "stop"
|
||||
222
tests/hermes_cli/test_ai_gateway_models.py
Normal file
222
tests/hermes_cli/test_ai_gateway_models.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""AI Gateway provider UX, live pricing, and model promotion tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from hermes_cli import models as models_module
|
||||
from hermes_cli.models import (
|
||||
CANONICAL_PROVIDERS,
|
||||
VERCEL_AI_GATEWAY_MODELS,
|
||||
_ai_gateway_model_is_free,
|
||||
ai_gateway_model_ids,
|
||||
fetch_ai_gateway_models,
|
||||
fetch_ai_gateway_pricing,
|
||||
get_pricing_for_provider,
|
||||
)
|
||||
|
||||
|
||||
def _mock_urlopen(payload):
|
||||
resp = MagicMock()
|
||||
resp.read.return_value = json.dumps(payload).encode()
|
||||
ctx = MagicMock()
|
||||
ctx.__enter__.return_value = resp
|
||||
ctx.__exit__.return_value = False
|
||||
return ctx
|
||||
|
||||
|
||||
def _reset_caches():
|
||||
models_module._ai_gateway_catalog_cache = None
|
||||
models_module._pricing_cache.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config_home(tmp_path, monkeypatch):
|
||||
home = tmp_path / "hermes"
|
||||
home.mkdir()
|
||||
(home / "config.yaml").write_text("model: some-old-model\n")
|
||||
(home / ".env").write_text("")
|
||||
monkeypatch.setenv("HERMES_HOME", str(home))
|
||||
monkeypatch.delenv("AI_GATEWAY_API_KEY", raising=False)
|
||||
monkeypatch.delenv("AI_GATEWAY_BASE_URL", raising=False)
|
||||
return home
|
||||
|
||||
|
||||
def test_ai_gateway_provider_is_promoted_near_top_of_picker():
|
||||
slugs = [entry.slug for entry in CANONICAL_PROVIDERS]
|
||||
assert "ai-gateway" in slugs[:3]
|
||||
|
||||
|
||||
def test_ai_gateway_pricing_translates_input_output_to_prompt_completion():
|
||||
_reset_caches()
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"id": "moonshotai/kimi-k2.5",
|
||||
"type": "language",
|
||||
"pricing": {
|
||||
"input": "0.0000006",
|
||||
"output": "0.0000025",
|
||||
"input_cache_read": "0.00000015",
|
||||
"input_cache_write": "0.0000006",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
with patch("urllib.request.urlopen", return_value=_mock_urlopen(payload)):
|
||||
result = fetch_ai_gateway_pricing(force_refresh=True)
|
||||
|
||||
entry = result["moonshotai/kimi-k2.5"]
|
||||
assert entry["prompt"] == "0.0000006"
|
||||
assert entry["completion"] == "0.0000025"
|
||||
assert entry["input_cache_read"] == "0.00000015"
|
||||
assert entry["input_cache_write"] == "0.0000006"
|
||||
|
||||
|
||||
def test_get_pricing_for_provider_supports_ai_gateway():
|
||||
_reset_caches()
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"id": "moonshotai/kimi-k2.5",
|
||||
"type": "language",
|
||||
"pricing": {"input": "0.0001", "output": "0.0002"},
|
||||
}
|
||||
]
|
||||
}
|
||||
with patch("urllib.request.urlopen", return_value=_mock_urlopen(payload)):
|
||||
result = get_pricing_for_provider("ai-gateway", force_refresh=True)
|
||||
assert result["moonshotai/kimi-k2.5"] == {"prompt": "0.0001", "completion": "0.0002"}
|
||||
|
||||
|
||||
def test_ai_gateway_pricing_returns_empty_on_fetch_failure():
|
||||
_reset_caches()
|
||||
with patch("urllib.request.urlopen", side_effect=OSError("network down")):
|
||||
result = fetch_ai_gateway_pricing(force_refresh=True)
|
||||
assert result == {}
|
||||
|
||||
|
||||
def test_ai_gateway_pricing_skips_entries_without_pricing_dict():
|
||||
_reset_caches()
|
||||
payload = {
|
||||
"data": [
|
||||
{"id": "x/y", "pricing": None},
|
||||
{"id": "a/b", "pricing": {"input": "0", "output": "0"}},
|
||||
]
|
||||
}
|
||||
with patch("urllib.request.urlopen", return_value=_mock_urlopen(payload)):
|
||||
result = fetch_ai_gateway_pricing(force_refresh=True)
|
||||
assert "x/y" not in result
|
||||
assert result["a/b"] == {"prompt": "0", "completion": "0"}
|
||||
|
||||
|
||||
def test_ai_gateway_free_detector():
|
||||
assert _ai_gateway_model_is_free({"input": "0", "output": "0"}) is True
|
||||
assert _ai_gateway_model_is_free({"input": "0", "output": "0.01"}) is False
|
||||
assert _ai_gateway_model_is_free({"input": "0.01", "output": "0"}) is False
|
||||
assert _ai_gateway_model_is_free(None) is False
|
||||
assert _ai_gateway_model_is_free({"input": "not a number"}) is False
|
||||
|
||||
|
||||
def test_fetch_ai_gateway_models_filters_against_live_catalog():
|
||||
_reset_caches()
|
||||
preferred = [mid for mid, _ in VERCEL_AI_GATEWAY_MODELS]
|
||||
live_ids = preferred[:3]
|
||||
payload = {
|
||||
"data": [
|
||||
{"id": mid, "pricing": {"input": "0.001", "output": "0.002"}}
|
||||
for mid in live_ids
|
||||
]
|
||||
}
|
||||
with patch("urllib.request.urlopen", return_value=_mock_urlopen(payload)):
|
||||
result = fetch_ai_gateway_models(force_refresh=True)
|
||||
|
||||
assert [mid for mid, _ in result] == live_ids
|
||||
assert result[0][1] == "recommended"
|
||||
assert ai_gateway_model_ids(force_refresh=False) == live_ids
|
||||
|
||||
|
||||
def test_fetch_ai_gateway_models_tags_free_models():
|
||||
_reset_caches()
|
||||
first_id = VERCEL_AI_GATEWAY_MODELS[0][0]
|
||||
second_id = VERCEL_AI_GATEWAY_MODELS[1][0]
|
||||
payload = {
|
||||
"data": [
|
||||
{"id": first_id, "pricing": {"input": "0.001", "output": "0.002"}},
|
||||
{"id": second_id, "pricing": {"input": "0", "output": "0"}},
|
||||
]
|
||||
}
|
||||
with patch("urllib.request.urlopen", return_value=_mock_urlopen(payload)):
|
||||
result = fetch_ai_gateway_models(force_refresh=True)
|
||||
|
||||
by_id = dict(result)
|
||||
assert by_id[first_id] == "recommended"
|
||||
assert by_id[second_id] == "free"
|
||||
|
||||
|
||||
def test_free_moonshot_model_auto_promoted_to_top_even_if_not_curated():
|
||||
_reset_caches()
|
||||
first_curated = VERCEL_AI_GATEWAY_MODELS[0][0]
|
||||
unlisted_free_moonshot = "moonshotai/kimi-coder-free-preview"
|
||||
payload = {
|
||||
"data": [
|
||||
{"id": first_curated, "pricing": {"input": "0.001", "output": "0.002"}},
|
||||
{"id": unlisted_free_moonshot, "pricing": {"input": "0", "output": "0"}},
|
||||
]
|
||||
}
|
||||
with patch("urllib.request.urlopen", return_value=_mock_urlopen(payload)):
|
||||
result = fetch_ai_gateway_models(force_refresh=True)
|
||||
|
||||
assert result[0] == (unlisted_free_moonshot, "recommended")
|
||||
assert any(mid == first_curated for mid, _ in result)
|
||||
|
||||
|
||||
def test_paid_moonshot_does_not_get_auto_promoted():
|
||||
_reset_caches()
|
||||
first_curated = VERCEL_AI_GATEWAY_MODELS[0][0]
|
||||
payload = {
|
||||
"data": [
|
||||
{"id": first_curated, "pricing": {"input": "0.001", "output": "0.002"}},
|
||||
{"id": "moonshotai/some-paid-variant", "pricing": {"input": "0.001", "output": "0.002"}},
|
||||
]
|
||||
}
|
||||
with patch("urllib.request.urlopen", return_value=_mock_urlopen(payload)):
|
||||
result = fetch_ai_gateway_models(force_refresh=True)
|
||||
|
||||
assert result[0][0] == first_curated
|
||||
|
||||
|
||||
def test_fetch_ai_gateway_models_falls_back_on_error():
|
||||
_reset_caches()
|
||||
with patch("urllib.request.urlopen", side_effect=OSError("network")):
|
||||
result = fetch_ai_gateway_models(force_refresh=True)
|
||||
assert result == list(VERCEL_AI_GATEWAY_MODELS)
|
||||
|
||||
|
||||
def test_ai_gateway_setup_flow_shows_deeplink_and_passes_pricing(config_home, monkeypatch, capsys):
|
||||
from hermes_cli.main import _model_flow_ai_gateway
|
||||
from hermes_cli.config import load_config
|
||||
|
||||
pricing = {"moonshotai/kimi-k2.6": {"prompt": "0", "completion": "0"}}
|
||||
monkeypatch.setenv("HERMES_HOME", str(config_home))
|
||||
|
||||
with patch("getpass.getpass", return_value="vercel-key"), \
|
||||
patch("hermes_cli.models.ai_gateway_model_ids", return_value=["moonshotai/kimi-k2.6"]), \
|
||||
patch("hermes_cli.models.get_pricing_for_provider", return_value=pricing), \
|
||||
patch("hermes_cli.auth._prompt_model_selection", return_value="moonshotai/kimi-k2.6") as prompt_selection, \
|
||||
patch("hermes_cli.auth.deactivate_provider"):
|
||||
_model_flow_ai_gateway(load_config(), "")
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai-gateway&title=AI+Gateway" in out
|
||||
assert "free credits" in out.lower()
|
||||
assert prompt_selection.call_args.kwargs["pricing"] == pricing
|
||||
|
||||
import yaml
|
||||
config = yaml.safe_load((config_home / "config.yaml").read_text()) or {}
|
||||
model = config["model"]
|
||||
assert model["provider"] == "ai-gateway"
|
||||
assert model["api_mode"] == "chat_completions"
|
||||
62
tests/run_agent/test_provider_attribution_headers.py
Normal file
62
tests/run_agent/test_provider_attribution_headers.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Attribution default_headers applied per provider via base-URL detection."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from run_agent import AIAgent
|
||||
|
||||
|
||||
@patch("run_agent.OpenAI")
|
||||
def test_openrouter_base_url_applies_or_headers(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
agent._apply_client_headers_for_base_url("https://openrouter.ai/api/v1")
|
||||
|
||||
headers = agent._client_kwargs["default_headers"]
|
||||
assert headers["HTTP-Referer"] == "https://hermes-agent.nousresearch.com"
|
||||
assert headers["X-OpenRouter-Title"] == "Hermes Agent"
|
||||
|
||||
|
||||
@patch("run_agent.OpenAI")
|
||||
def test_ai_gateway_base_url_applies_attribution_headers(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
|
||||
agent._apply_client_headers_for_base_url("https://ai-gateway.vercel.sh/v1")
|
||||
|
||||
headers = agent._client_kwargs["default_headers"]
|
||||
assert headers["HTTP-Referer"] == "https://hermes-agent.nousresearch.com"
|
||||
assert headers["X-Title"] == "Hermes Agent"
|
||||
assert headers["User-Agent"].startswith("HermesAgent/")
|
||||
|
||||
|
||||
@patch("run_agent.OpenAI")
|
||||
def test_unknown_base_url_clears_default_headers(mock_openai):
|
||||
mock_openai.return_value = MagicMock()
|
||||
agent = AIAgent(
|
||||
api_key="test-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="test/model",
|
||||
quiet_mode=True,
|
||||
skip_context_files=True,
|
||||
skip_memory=True,
|
||||
)
|
||||
agent._client_kwargs["default_headers"] = {"X-Stale": "yes"}
|
||||
|
||||
agent._apply_client_headers_for_base_url("https://api.example.com/v1")
|
||||
|
||||
assert "default_headers" not in agent._client_kwargs
|
||||
Reference in New Issue
Block a user