Use atomic_json_write() from utils.py instead of plain open()/json.dump()
for the models.dev disk cache. Prevents corrupted cache if the process is
killed mid-write — _load_disk_cache() silently returns {} on corrupt JSON,
losing all model metadata until the next successful API fetch.
Co-authored-by: memosr <memosr@users.noreply.github.com>
172 lines
5.3 KiB
Python
172 lines
5.3 KiB
Python
"""Models.dev registry integration for provider-aware context length detection.
|
|
|
|
Fetches model metadata from https://models.dev/api.json — a community-maintained
|
|
database of 3800+ models across 100+ providers, including per-provider context
|
|
windows, pricing, and capabilities.
|
|
|
|
Data is cached in memory (1hr TTL) and on disk (~/.hermes/models_dev_cache.json)
|
|
to avoid cold-start network latency.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
from utils import atomic_json_write
|
|
|
|
import requests
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
MODELS_DEV_URL = "https://models.dev/api.json"
|
|
_MODELS_DEV_CACHE_TTL = 3600 # 1 hour in-memory
|
|
|
|
# In-memory cache
|
|
_models_dev_cache: Dict[str, Any] = {}
|
|
_models_dev_cache_time: float = 0
|
|
|
|
# Provider ID mapping: Hermes provider names → models.dev provider IDs
|
|
PROVIDER_TO_MODELS_DEV: Dict[str, str] = {
|
|
"openrouter": "openrouter",
|
|
"anthropic": "anthropic",
|
|
"zai": "zai",
|
|
"kimi-coding": "kimi-for-coding",
|
|
"minimax": "minimax",
|
|
"minimax-cn": "minimax-cn",
|
|
"deepseek": "deepseek",
|
|
"alibaba": "alibaba",
|
|
"copilot": "github-copilot",
|
|
"ai-gateway": "vercel",
|
|
"opencode-zen": "opencode",
|
|
"opencode-go": "opencode-go",
|
|
"kilocode": "kilo",
|
|
}
|
|
|
|
|
|
def _get_cache_path() -> Path:
|
|
"""Return path to disk cache file."""
|
|
env_val = os.environ.get("HERMES_HOME", "")
|
|
hermes_home = Path(env_val) if env_val else Path.home() / ".hermes"
|
|
return hermes_home / "models_dev_cache.json"
|
|
|
|
|
|
def _load_disk_cache() -> Dict[str, Any]:
|
|
"""Load models.dev data from disk cache."""
|
|
try:
|
|
cache_path = _get_cache_path()
|
|
if cache_path.exists():
|
|
with open(cache_path, encoding="utf-8") as f:
|
|
return json.load(f)
|
|
except Exception as e:
|
|
logger.debug("Failed to load models.dev disk cache: %s", e)
|
|
return {}
|
|
|
|
|
|
def _save_disk_cache(data: Dict[str, Any]) -> None:
|
|
"""Save models.dev data to disk cache atomically."""
|
|
try:
|
|
cache_path = _get_cache_path()
|
|
atomic_json_write(cache_path, data, indent=None, separators=(",", ":"))
|
|
except Exception as e:
|
|
logger.debug("Failed to save models.dev disk cache: %s", e)
|
|
|
|
|
|
def fetch_models_dev(force_refresh: bool = False) -> Dict[str, Any]:
|
|
"""Fetch models.dev registry. In-memory cache (1hr) + disk fallback.
|
|
|
|
Returns the full registry dict keyed by provider ID, or empty dict on failure.
|
|
"""
|
|
global _models_dev_cache, _models_dev_cache_time
|
|
|
|
# Check in-memory cache
|
|
if (
|
|
not force_refresh
|
|
and _models_dev_cache
|
|
and (time.time() - _models_dev_cache_time) < _MODELS_DEV_CACHE_TTL
|
|
):
|
|
return _models_dev_cache
|
|
|
|
# Try network fetch
|
|
try:
|
|
response = requests.get(MODELS_DEV_URL, timeout=15)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
if isinstance(data, dict) and len(data) > 0:
|
|
_models_dev_cache = data
|
|
_models_dev_cache_time = time.time()
|
|
_save_disk_cache(data)
|
|
logger.debug(
|
|
"Fetched models.dev registry: %d providers, %d total models",
|
|
len(data),
|
|
sum(len(p.get("models", {})) for p in data.values() if isinstance(p, dict)),
|
|
)
|
|
return data
|
|
except Exception as e:
|
|
logger.debug("Failed to fetch models.dev: %s", e)
|
|
|
|
# Fall back to disk cache — use a short TTL (5 min) so we retry
|
|
# the network fetch soon instead of serving stale data for a full hour.
|
|
if not _models_dev_cache:
|
|
_models_dev_cache = _load_disk_cache()
|
|
if _models_dev_cache:
|
|
_models_dev_cache_time = time.time() - _MODELS_DEV_CACHE_TTL + 300
|
|
logger.debug("Loaded models.dev from disk cache (%d providers)", len(_models_dev_cache))
|
|
|
|
return _models_dev_cache
|
|
|
|
|
|
def lookup_models_dev_context(provider: str, model: str) -> Optional[int]:
|
|
"""Look up context_length for a provider+model combo in models.dev.
|
|
|
|
Returns the context window in tokens, or None if not found.
|
|
Handles case-insensitive matching and filters out context=0 entries.
|
|
"""
|
|
mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider)
|
|
if not mdev_provider_id:
|
|
return None
|
|
|
|
data = fetch_models_dev()
|
|
provider_data = data.get(mdev_provider_id)
|
|
if not isinstance(provider_data, dict):
|
|
return None
|
|
|
|
models = provider_data.get("models", {})
|
|
if not isinstance(models, dict):
|
|
return None
|
|
|
|
# Exact match
|
|
entry = models.get(model)
|
|
if entry:
|
|
ctx = _extract_context(entry)
|
|
if ctx:
|
|
return ctx
|
|
|
|
# Case-insensitive match
|
|
model_lower = model.lower()
|
|
for mid, mdata in models.items():
|
|
if mid.lower() == model_lower:
|
|
ctx = _extract_context(mdata)
|
|
if ctx:
|
|
return ctx
|
|
|
|
return None
|
|
|
|
|
|
def _extract_context(entry: Dict[str, Any]) -> Optional[int]:
|
|
"""Extract context_length from a models.dev model entry.
|
|
|
|
Returns None for invalid/zero values (some audio/image models have context=0).
|
|
"""
|
|
if not isinstance(entry, dict):
|
|
return None
|
|
limit = entry.get("limit")
|
|
if not isinstance(limit, dict):
|
|
return None
|
|
ctx = limit.get("context")
|
|
if isinstance(ctx, (int, float)) and ctx > 0:
|
|
return int(ctx)
|
|
return None
|