# Follow-up for salvaged PR #5494:
# - Update model catalog to Gemini 3.x + Gemma 4 (drop deprecated 2.0)
# - Add list_agentic_models() to models_dev.py with noise filter
# - Wire models.dev into _model_flow_api_key_provider as primary source
#   (static curated list serves as offline fallback)
# - Add gemini -> google mapping in PROVIDER_TO_MODELS_DEV
# - Fix Gemma 4 context lengths to 256K (models.dev values)
# - Update auxiliary model to gemini-3-flash-preview
# - Expand tests: 3.x catalog, context lengths, models.dev integration
"""Models.dev registry integration — primary database for providers and models.
|
|
|
|
Fetches from https://models.dev/api.json — a community-maintained database
|
|
of 4000+ models across 109+ providers. Provides:
|
|
|
|
- **Provider metadata**: name, base URL, env vars, documentation link
|
|
- **Model metadata**: context window, max output, cost/M tokens, capabilities
|
|
(reasoning, tools, vision, PDF, audio), modalities, knowledge cutoff,
|
|
open-weights flag, family grouping, deprecation status
|
|
|
|
Data resolution order (like TypeScript OpenCode):
|
|
1. Bundled snapshot (ships with the package — offline-first)
|
|
2. Disk cache (~/.hermes/models_dev_cache.json)
|
|
3. Network fetch (https://models.dev/api.json)
|
|
4. Background refresh every 60 minutes
|
|
|
|
Other modules should import the dataclasses and query functions from here
|
|
rather than parsing the raw JSON themselves.
|
|
"""
|
|
|
|
import difflib
import json
import logging
import os
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import requests

from utils import atomic_json_write
|
|
|
|
# Module-level logger; handlers/levels are configured by the application.
logger = logging.getLogger(__name__)

# Canonical registry endpoint — one JSON document covering every provider.
MODELS_DEV_URL = "https://models.dev/api.json"
_MODELS_DEV_CACHE_TTL = 3600  # 1 hour in-memory

# In-memory cache: the raw registry dict keyed by provider ID, plus the
# wall-clock time of the last refresh (0 means "never fetched").
_models_dev_cache: Dict[str, Any] = {}
_models_dev_cache_time: float = 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dataclasses — rich metadata for providers and models
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class ModelInfo:
    """Full metadata for a single model from models.dev."""

    id: str
    name: str
    family: str
    provider_id: str  # models.dev provider ID (e.g. "anthropic")

    # Capabilities
    reasoning: bool = False
    tool_call: bool = False
    attachment: bool = False  # supports image/file attachments (vision)
    temperature: bool = False
    structured_output: bool = False
    open_weights: bool = False

    # Modalities
    input_modalities: Tuple[str, ...] = ()  # ("text", "image", "pdf", ...)
    output_modalities: Tuple[str, ...] = ()

    # Limits
    context_window: int = 0
    max_output: int = 0
    max_input: Optional[int] = None

    # Cost (per million tokens, USD)
    cost_input: float = 0.0
    cost_output: float = 0.0
    cost_cache_read: Optional[float] = None
    cost_cache_write: Optional[float] = None

    # Metadata
    knowledge_cutoff: str = ""
    release_date: str = ""
    status: str = ""  # "alpha", "beta", "deprecated", or ""
    interleaved: Any = False  # True or {"field": "reasoning_content"}

    def has_cost_data(self) -> bool:
        """True when at least one of input/output pricing is a positive value."""
        return max(self.cost_input, self.cost_output) > 0

    def supports_vision(self) -> bool:
        """Vision = explicit attachment flag, or an "image" input modality."""
        return self.attachment or "image" in self.input_modalities

    def supports_pdf(self) -> bool:
        """True when the model declares a "pdf" input modality."""
        return "pdf" in self.input_modalities

    def supports_audio_input(self) -> bool:
        """True when the model declares an "audio" input modality."""
        return "audio" in self.input_modalities

    def format_cost(self) -> str:
        """Human-readable cost string, e.g. '$3.00/M in, $15.00/M out'."""
        if not self.has_cost_data():
            return "unknown"
        pieces = [
            f"${self.cost_input:.2f}/M in",
            f"${self.cost_output:.2f}/M out",
        ]
        # Cache-read pricing is optional; show it only when known.
        if self.cost_cache_read is not None:
            pieces.append(f"cache read ${self.cost_cache_read:.2f}/M")
        return ", ".join(pieces)

    def format_capabilities(self) -> str:
        """Human-readable capabilities, e.g. 'reasoning, tools, vision, PDF'."""
        # (flag, label) pairs in display order — keeps output stable.
        checks = (
            (self.reasoning, "reasoning"),
            (self.tool_call, "tools"),
            (self.supports_vision(), "vision"),
            (self.supports_pdf(), "PDF"),
            (self.supports_audio_input(), "audio"),
            (self.structured_output, "structured output"),
            (self.open_weights, "open weights"),
        )
        labels = [label for enabled, label in checks if enabled]
        return ", ".join(labels) if labels else "basic"
|
|
|
|
|
|
@dataclass
class ProviderInfo:
    """Full metadata for a provider from models.dev."""

    id: str  # models.dev provider ID
    name: str  # display name
    env: Tuple[str, ...]  # env var names for API key
    api: str  # base URL
    doc: str = ""  # documentation URL
    model_count: int = 0

    def has_api_url(self) -> bool:
        """True when a non-empty base URL is known for this provider."""
        return len(self.api) > 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Provider ID mapping: Hermes ↔ models.dev
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Hermes provider names → models.dev provider IDs
|
|
PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
|
"openrouter": "openrouter",
|
|
"anthropic": "anthropic",
|
|
"zai": "zai",
|
|
"kimi-coding": "kimi-for-coding",
|
|
"minimax": "minimax",
|
|
"minimax-cn": "minimax-cn",
|
|
"deepseek": "deepseek",
|
|
"alibaba": "alibaba",
|
|
"copilot": "github-copilot",
|
|
"ai-gateway": "vercel",
|
|
"opencode-zen": "opencode",
|
|
"opencode-go": "opencode-go",
|
|
"kilocode": "kilo",
|
|
"fireworks": "fireworks-ai",
|
|
"huggingface": "huggingface",
|
|
"gemini": "google",
|
|
"google": "google",
|
|
"xai": "xai",
|
|
"nvidia": "nvidia",
|
|
"groq": "groq",
|
|
"mistral": "mistral",
|
|
"togetherai": "togetherai",
|
|
"perplexity": "perplexity",
|
|
"cohere": "cohere",
|
|
}
|
|
|
|
# Reverse mapping: models.dev → Hermes (built lazily)
|
|
_MODELS_DEV_TO_PROVIDER: Optional[Dict[str, str]] = None
|
|
|
|
|
|
def _get_reverse_mapping() -> Dict[str, str]:
|
|
"""Return models.dev ID → Hermes provider ID mapping."""
|
|
global _MODELS_DEV_TO_PROVIDER
|
|
if _MODELS_DEV_TO_PROVIDER is None:
|
|
_MODELS_DEV_TO_PROVIDER = {v: k for k, v in PROVIDER_TO_MODELS_DEV.items()}
|
|
return _MODELS_DEV_TO_PROVIDER
|
|
|
|
|
|
def _get_cache_path() -> Path:
|
|
"""Return path to disk cache file."""
|
|
env_val = os.environ.get("HERMES_HOME", "")
|
|
hermes_home = Path(env_val) if env_val else Path.home() / ".hermes"
|
|
return hermes_home / "models_dev_cache.json"
|
|
|
|
|
|
def _load_disk_cache() -> Dict[str, Any]:
|
|
"""Load models.dev data from disk cache."""
|
|
try:
|
|
cache_path = _get_cache_path()
|
|
if cache_path.exists():
|
|
with open(cache_path, encoding="utf-8") as f:
|
|
return json.load(f)
|
|
except Exception as e:
|
|
logger.debug("Failed to load models.dev disk cache: %s", e)
|
|
return {}
|
|
|
|
|
|
def _save_disk_cache(data: Dict[str, Any]) -> None:
    """Persist models.dev data to the disk cache (atomic, best-effort).

    Written compact (no indent, tight separators) to keep the file small.
    Failures are logged at debug level and swallowed — caching is optional.
    """
    try:
        atomic_json_write(
            _get_cache_path(), data, indent=None, separators=(",", ":")
        )
    except Exception as e:
        logger.debug("Failed to save models.dev disk cache: %s", e)
|
|
|
|
|
|
def fetch_models_dev(force_refresh: bool = False) -> Dict[str, Any]:
    """Fetch models.dev registry. In-memory cache (1hr) + disk fallback.

    Args:
        force_refresh: When True, bypass the in-memory cache and go straight
            to the network (disk cache still serves as the last resort).

    Returns the full registry dict keyed by provider ID, or empty dict on failure.
    """
    global _models_dev_cache, _models_dev_cache_time

    # Check in-memory cache
    if (
        not force_refresh
        and _models_dev_cache
        and (time.time() - _models_dev_cache_time) < _MODELS_DEV_CACHE_TTL
    ):
        return _models_dev_cache

    # Try network fetch
    try:
        response = requests.get(MODELS_DEV_URL, timeout=15)
        response.raise_for_status()
        data = response.json()
        # Guard against empty/garbage payloads so junk never gets cached.
        if isinstance(data, dict) and len(data) > 0:
            _models_dev_cache = data
            _models_dev_cache_time = time.time()
            _save_disk_cache(data)
            logger.debug(
                "Fetched models.dev registry: %d providers, %d total models",
                len(data),
                sum(len(p.get("models", {})) for p in data.values() if isinstance(p, dict)),
            )
            return data
    except Exception as e:
        # Network failures are expected (offline use) — fall through to disk.
        logger.debug("Failed to fetch models.dev: %s", e)

    # Fall back to disk cache — use a short TTL (5 min) so we retry
    # the network fetch soon instead of serving stale data for a full hour.
    if not _models_dev_cache:
        _models_dev_cache = _load_disk_cache()
        if _models_dev_cache:
            # Backdate the timestamp so this entry expires in 300s, not 3600s.
            _models_dev_cache_time = time.time() - _MODELS_DEV_CACHE_TTL + 300
            logger.debug("Loaded models.dev from disk cache (%d providers)", len(_models_dev_cache))

    return _models_dev_cache
|
|
|
|
|
|
def lookup_models_dev_context(provider: str, model: str) -> Optional[int]:
    """Look up context_length for a provider+model combo in models.dev.

    Returns the context window in tokens, or None if not found.
    Handles case-insensitive matching and filters out context=0 entries.

    Provider resolution is delegated to _get_provider_models() so the
    mapping/validation logic lives in one place. Note the exact-match hit
    is only used when it carries a valid context; otherwise we keep
    scanning case-insensitive variants (preserves original behavior).
    """
    models = _get_provider_models(provider)
    if models is None:
        return None

    # Exact match first — cheapest path.
    entry = models.get(model)
    if entry:
        ctx = _extract_context(entry)
        if ctx:
            return ctx

    # Case-insensitive scan (also revisits the exact ID, harmlessly).
    model_lower = model.lower()
    for mid, mdata in models.items():
        if mid.lower() == model_lower:
            ctx = _extract_context(mdata)
            if ctx:
                return ctx

    return None
|
|
|
|
|
|
def _extract_context(entry: Dict[str, Any]) -> Optional[int]:
|
|
"""Extract context_length from a models.dev model entry.
|
|
|
|
Returns None for invalid/zero values (some audio/image models have context=0).
|
|
"""
|
|
if not isinstance(entry, dict):
|
|
return None
|
|
limit = entry.get("limit")
|
|
if not isinstance(limit, dict):
|
|
return None
|
|
ctx = limit.get("context")
|
|
if isinstance(ctx, (int, float)) and ctx > 0:
|
|
return int(ctx)
|
|
return None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model capability metadata
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@dataclass
class ModelCapabilities:
    """Structured capability metadata for a model from models.dev.

    Defaults are deliberately permissive/typical (tools on, 200K context,
    8K output) so callers get workable values when a field is missing
    upstream — see get_model_capabilities(), which applies the same
    fallbacks when extracting limits.
    """

    supports_tools: bool = True        # model accepts tool/function calls
    supports_vision: bool = False      # model accepts image attachments
    supports_reasoning: bool = False   # model exposes reasoning/thinking
    context_window: int = 200000       # prompt context limit, tokens
    max_output_tokens: int = 8192      # completion limit, tokens
    model_family: str = ""             # models.dev "family" grouping
|
|
|
|
|
|
def _get_provider_models(provider: str) -> Optional[Dict[str, Any]]:
    """Resolve a Hermes provider ID to its models dict from models.dev.

    Returns the models dict, or None if the provider is unmapped or the
    registry has no well-formed data for it.
    """
    mdev_id = PROVIDER_TO_MODELS_DEV.get(provider)
    if not mdev_id:
        return None

    provider_data = fetch_models_dev().get(mdev_id)
    if not isinstance(provider_data, dict):
        return None

    models = provider_data.get("models", {})
    return models if isinstance(models, dict) else None
|
|
|
|
|
|
def _find_model_entry(models: Dict[str, Any], model: str) -> Optional[Dict[str, Any]]:
|
|
"""Find a model entry by exact match, then case-insensitive fallback."""
|
|
# Exact match
|
|
entry = models.get(model)
|
|
if isinstance(entry, dict):
|
|
return entry
|
|
|
|
# Case-insensitive match
|
|
model_lower = model.lower()
|
|
for mid, mdata in models.items():
|
|
if mid.lower() == model_lower and isinstance(mdata, dict):
|
|
return mdata
|
|
|
|
return None
|
|
|
|
|
|
def get_model_capabilities(provider: str, model: str) -> Optional[ModelCapabilities]:
    """Look up full capability metadata from models.dev cache.

    Uses the existing fetch_models_dev() and PROVIDER_TO_MODELS_DEV mapping.
    Returns None if model not found.

    Extracts from model entry fields:
    - reasoning (bool) → supports_reasoning
    - tool_call (bool) → supports_tools
    - attachment (bool) → supports_vision
    - limit.context (int) → context_window
    - limit.output (int) → max_output_tokens
    - family (str) → model_family
    """
    models = _get_provider_models(provider)
    if models is None:
        return None

    entry = _find_model_entry(models, model)
    if entry is None:
        return None

    limit = entry.get("limit", {})
    if not isinstance(limit, dict):
        limit = {}

    def _positive_int(value: Any, fallback: int) -> int:
        # Numeric and > 0, else the documented default.
        return int(value) if isinstance(value, (int, float)) and value > 0 else fallback

    return ModelCapabilities(
        supports_tools=bool(entry.get("tool_call", False)),
        supports_vision=bool(entry.get("attachment", False)),
        supports_reasoning=bool(entry.get("reasoning", False)),
        context_window=_positive_int(limit.get("context"), 200000),
        max_output_tokens=_positive_int(limit.get("output"), 8192),
        model_family=entry.get("family", "") or "",
    )
|
|
|
|
|
|
def list_provider_models(provider: str) -> List[str]:
    """Return all model IDs for a provider from models.dev.

    Returns an empty list if the provider is unknown or has no data.
    """
    models = _get_provider_models(provider)
    return [] if models is None else list(models)
|
|
|
|
|
|
# Patterns that indicate non-agentic or noise models (TTS, embedding,
|
|
# dated preview snapshots, live/streaming-only, image-only).
|
|
import re
|
|
_NOISE_PATTERNS: re.Pattern = re.compile(
|
|
r"-tts\b|embedding|live-|-(preview|exp)-\d{2,4}[-_]|"
|
|
r"-image\b|-image-preview\b|-customtools\b",
|
|
re.IGNORECASE,
|
|
)
|
|
|
|
|
|
def list_agentic_models(provider: str) -> List[str]:
    """Return model IDs suitable for agentic use from models.dev.

    Filters for tool_call=True and excludes noise (TTS, embedding,
    dated preview snapshots, live/streaming, image-only models).
    Returns an empty list on any failure.
    """
    models = _get_provider_models(provider)
    if models is None:
        return []

    return [
        mid
        for mid, entry in models.items()
        if isinstance(entry, dict)
        and entry.get("tool_call", False)
        and not _NOISE_PATTERNS.search(mid)
    ]
|
|
|
|
|
|
def search_models_dev(
    query: str, provider: Optional[str] = None, limit: int = 5
) -> List[Dict[str, Any]]:
    """Fuzzy search across models.dev catalog. Returns matching model entries.

    Args:
        query: Search string to match against model IDs.
        provider: Optional Hermes provider ID to restrict search scope.
            If None, searches across all providers in PROVIDER_TO_MODELS_DEV.
        limit: Maximum number of results to return.

    Returns:
        List of dicts, each containing 'provider', 'model_id', and the full
        model 'entry' from models.dev. Substring matches rank ahead of
        difflib fuzzy matches; duplicates are removed.
    """
    data = fetch_models_dev()
    if not data:
        return []

    # Build list of (provider_id, model_id, entry) candidates
    candidates: List[Tuple[str, str, Any]] = []

    if provider is not None:
        # Search only the specified provider
        mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider)
        if not mdev_provider_id:
            return []
        provider_data = data.get(mdev_provider_id, {})
        if isinstance(provider_data, dict):
            models = provider_data.get("models", {})
            if isinstance(models, dict):
                for mid, mdata in models.items():
                    candidates.append((provider, mid, mdata))
    else:
        # Search across all mapped providers
        for hermes_prov, mdev_prov in PROVIDER_TO_MODELS_DEV.items():
            provider_data = data.get(mdev_prov, {})
            if isinstance(provider_data, dict):
                models = provider_data.get("models", {})
                if isinstance(models, dict):
                    for mid, mdata in models.items():
                        candidates.append((hermes_prov, mid, mdata))

    if not candidates:
        return []

    # Use difflib for fuzzy matching — case-insensitive comparison
    model_ids_lower = [c[1].lower() for c in candidates]
    query_lower = query.lower()

    # First try exact substring matches (more intuitive than pure edit-distance)
    substring_matches = []
    for prov, mid, mdata in candidates:
        if query_lower in mid.lower():
            substring_matches.append({"provider": prov, "model_id": mid, "entry": mdata})

    # Then add difflib fuzzy matches for any remaining slots
    # (over-fetch 2x so de-dup against substring hits still fills `limit`)
    fuzzy_ids = difflib.get_close_matches(
        query_lower, model_ids_lower, n=limit * 2, cutoff=0.4
    )

    seen_ids: set = set()
    results: List[Dict[str, Any]] = []

    # Prioritize substring matches
    for match in substring_matches:
        key = (match["provider"], match["model_id"])
        if key not in seen_ids:
            seen_ids.add(key)
            results.append(match)
        if len(results) >= limit:
            return results

    # Add fuzzy matches
    for fid in fuzzy_ids:
        # Find original-case candidates matching this lowered ID
        for prov, mid, mdata in candidates:
            if mid.lower() == fid:
                key = (prov, mid)
                if key not in seen_ids:
                    seen_ids.add(key)
                    results.append({"provider": prov, "model_id": mid, "entry": mdata})
            if len(results) >= limit:
                return results

    return results
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Rich dataclass constructors — parse raw models.dev JSON into dataclasses
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _parse_model_info(model_id: str, raw: Dict[str, Any], provider_id: str) -> ModelInfo:
    """Convert a raw models.dev model entry dict into a ModelInfo dataclass."""

    def _as_dict(value: Any) -> Dict[str, Any]:
        # Registry sub-objects are occasionally null or mistyped upstream.
        return value if isinstance(value, dict) else {}

    def _pos_int(value: Any, default: Optional[int]) -> Optional[int]:
        # Positive numeric → int, anything else → the caller's default.
        return int(value) if isinstance(value, (int, float)) and value > 0 else default

    limit = _as_dict(raw.get("limit") or {})
    cost = _as_dict(raw.get("cost") or {})
    modalities = _as_dict(raw.get("modalities") or {})

    input_mods = modalities.get("input") or []
    output_mods = modalities.get("output") or []

    return ModelInfo(
        id=model_id,
        name=raw.get("name", "") or model_id,
        family=raw.get("family", "") or "",
        provider_id=provider_id,
        reasoning=bool(raw.get("reasoning", False)),
        tool_call=bool(raw.get("tool_call", False)),
        attachment=bool(raw.get("attachment", False)),
        temperature=bool(raw.get("temperature", False)),
        structured_output=bool(raw.get("structured_output", False)),
        open_weights=bool(raw.get("open_weights", False)),
        input_modalities=tuple(input_mods) if isinstance(input_mods, list) else (),
        output_modalities=tuple(output_mods) if isinstance(output_mods, list) else (),
        context_window=_pos_int(limit.get("context"), 0),
        max_output=_pos_int(limit.get("output"), 0),
        max_input=_pos_int(limit.get("input"), None),
        cost_input=float(cost.get("input", 0) or 0),
        cost_output=float(cost.get("output", 0) or 0),
        cost_cache_read=float(cost["cache_read"]) if cost.get("cache_read") is not None else None,
        cost_cache_write=float(cost["cache_write"]) if cost.get("cache_write") is not None else None,
        knowledge_cutoff=raw.get("knowledge", "") or "",
        release_date=raw.get("release_date", "") or "",
        status=raw.get("status", "") or "",
        interleaved=raw.get("interleaved", False),
    )
|
|
|
|
|
|
def _parse_provider_info(provider_id: str, raw: Dict[str, Any]) -> ProviderInfo:
    """Convert a raw models.dev provider entry dict into a ProviderInfo."""
    env_vars = raw.get("env") or []
    model_map = raw.get("models") or {}
    return ProviderInfo(
        id=provider_id,
        name=raw.get("name", "") or provider_id,
        # Defend against mistyped registry entries — fall back to neutral values.
        env=tuple(env_vars) if isinstance(env_vars, list) else (),
        api=raw.get("api", "") or "",
        doc=raw.get("doc", "") or "",
        model_count=len(model_map) if isinstance(model_map, dict) else 0,
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Provider-level queries
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_provider_info(provider_id: str) -> Optional[ProviderInfo]:
    """Get full provider metadata from models.dev.

    Accepts either a Hermes provider ID (e.g. "kilocode") or a models.dev
    ID (e.g. "kilo"). Returns None if the provider is not in the catalog.
    """
    # Hermes ID → models.dev ID; unmapped IDs pass through unchanged.
    mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id)
    raw = fetch_models_dev().get(mdev_id)
    return _parse_provider_info(mdev_id, raw) if isinstance(raw, dict) else None
|
|
|
|
|
|
def list_all_providers() -> Dict[str, ProviderInfo]:
    """Return all providers from models.dev as {provider_id: ProviderInfo}.

    Returns the full catalog — 109+ providers, keyed by models.dev ID.
    Malformed (non-dict) entries are skipped.
    """
    return {
        pid: _parse_provider_info(pid, pdata)
        for pid, pdata in fetch_models_dev().items()
        if isinstance(pdata, dict)
    }
|
|
|
|
|
|
def get_providers_for_env_var(env_var: str) -> List[str]:
    """Reverse lookup: find all providers that use a given env var.

    Useful for auto-detection: "user has ANTHROPIC_API_KEY set, which
    providers does that enable?"

    Returns list of models.dev provider IDs.
    """
    matches: List[str] = []
    for pid, pdata in fetch_models_dev().items():
        if not isinstance(pdata, dict):
            continue
        declared = pdata.get("env", [])
        if isinstance(declared, list) and env_var in declared:
            matches.append(pid)
    return matches
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model-level queries (rich ModelInfo)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def get_model_info(
    provider_id: str, model_id: str
) -> Optional[ModelInfo]:
    """Get full model metadata from models.dev.

    Accepts Hermes or models.dev provider ID. Tries exact match then
    case-insensitive fallback. Returns None if not found.
    """
    mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id)

    pdata = fetch_models_dev().get(mdev_id)
    if not isinstance(pdata, dict):
        return None
    models = pdata.get("models", {})
    if not isinstance(models, dict):
        return None

    # Exact ID hit first.
    raw = models.get(model_id)
    if isinstance(raw, dict):
        return _parse_model_info(model_id, raw, mdev_id)

    # Otherwise accept the first case-insensitive ID match, keeping the
    # catalog's canonical casing in the returned ModelInfo.
    wanted = model_id.lower()
    for mid, mdata in models.items():
        if mid.lower() == wanted and isinstance(mdata, dict):
            return _parse_model_info(mid, mdata, mdev_id)

    return None
|
|
|
|
|
|
def get_model_info_any_provider(model_id: str) -> Optional[ModelInfo]:
    """Search all providers for a model by ID.

    Useful when you have a full slug like "anthropic/claude-sonnet-4.6" or
    a bare name and want to find it anywhere. Checks Hermes-mapped providers
    first, then falls back to all models.dev providers.

    Both passes now do exact-then-case-insensitive matching (previously the
    unmapped-provider fallback was exact-only, inconsistent with the first
    pass); backward compatible — strictly more lookups succeed.
    """
    data = fetch_models_dev()
    model_lower = model_id.lower()

    def _scan(models: Dict[str, Any], pid: str) -> Optional[ModelInfo]:
        # Exact match, then case-insensitive — shared by both passes.
        raw = models.get(model_id)
        if isinstance(raw, dict):
            return _parse_model_info(model_id, raw, pid)
        for mid, mdata in models.items():
            if mid.lower() == model_lower and isinstance(mdata, dict):
                return _parse_model_info(mid, mdata, pid)
        return None

    # Try Hermes-mapped providers first (more likely what the user wants)
    for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items():
        pdata = data.get(mdev_id)
        if not isinstance(pdata, dict):
            continue
        models = pdata.get("models", {})
        if not isinstance(models, dict):
            continue
        found = _scan(models, mdev_id)
        if found is not None:
            return found

    # Fall back to ALL providers
    for pid, pdata in data.items():
        if pid in _get_reverse_mapping():
            continue  # already checked
        if not isinstance(pdata, dict):
            continue
        models = pdata.get("models", {})
        if not isinstance(models, dict):
            continue
        found = _scan(models, pid)
        if found is not None:
            return found

    return None
|
|
|
|
|
|
def list_provider_model_infos(provider_id: str) -> List[ModelInfo]:
    """Return all models for a provider as ModelInfo objects.

    Filters out deprecated models by default.
    """
    mdev_id = PROVIDER_TO_MODELS_DEV.get(provider_id, provider_id)

    pdata = fetch_models_dev().get(mdev_id)
    if not isinstance(pdata, dict):
        return []
    models = pdata.get("models", {})
    if not isinstance(models, dict):
        return []

    return [
        _parse_model_info(mid, mdata, mdev_id)
        for mid, mdata in models.items()
        if isinstance(mdata, dict) and mdata.get("status", "") != "deprecated"
    ]