diff --git a/AGENTS.md b/AGENTS.md index 181547eb4..fa733bc00 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -23,6 +23,7 @@ hermes-agent/ │ ├── prompt_caching.py # Anthropic prompt caching │ ├── auxiliary_client.py # Auxiliary LLM client (vision, summarization) │ ├── model_metadata.py # Model context lengths, token estimation +│ ├── models_dev.py # models.dev registry integration (provider-aware context) │ ├── display.py # KawaiiSpinner, tool preview formatting │ ├── skill_commands.py # Skill slash commands (shared CLI/gateway) │ └── trajectory.py # Trajectory saving helpers diff --git a/agent/context_compressor.py b/agent/context_compressor.py index 6ba935505..586d22626 100644 --- a/agent/context_compressor.py +++ b/agent/context_compressor.py @@ -47,10 +47,12 @@ class ContextCompressor: base_url: str = "", api_key: str = "", config_context_length: int | None = None, + provider: str = "", ): self.model = model self.base_url = base_url self.api_key = api_key + self.provider = provider self.threshold_percent = threshold_percent self.protect_first_n = protect_first_n self.protect_last_n = protect_last_n @@ -60,6 +62,7 @@ class ContextCompressor: self.context_length = get_model_context_length( model, base_url=base_url, api_key=api_key, config_context_length=config_context_length, + provider=provider, ) self.threshold_tokens = int(self.context_length * threshold_percent) self.compression_count = 0 diff --git a/agent/model_metadata.py b/agent/model_metadata.py index 758ca0520..9ed6c4a2b 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -55,104 +55,52 @@ _endpoint_model_metadata_cache_time: Dict[str, float] = {} _ENDPOINT_MODEL_CACHE_TTL = 300 # Descending tiers for context length probing when the model is unknown. -# We start high and step down on context-length errors until one works. +# We start at 128K (a safe default for most modern models) and step down +# on context-length errors until one works. CONTEXT_PROBE_TIERS = [ - 2_000_000, - 1_000_000, - 512_000, - 200_000, 128_000, 64_000, 32_000, + 16_000, + 8_000, ] +# Default context length when no detection method succeeds. +DEFAULT_FALLBACK_CONTEXT = CONTEXT_PROBE_TIERS[0] + +# Thin fallback defaults — only broad model family patterns. +# These fire only when provider is unknown AND models.dev/OpenRouter/Anthropic +# all miss. Replaced the previous 80+ entry dict. +# For provider-specific context lengths, models.dev is the primary source. DEFAULT_CONTEXT_LENGTHS = { - "anthropic/claude-opus-4": 200000, - "anthropic/claude-opus-4.5": 200000, - "anthropic/claude-opus-4.6": 1000000, - "anthropic/claude-sonnet-4": 200000, - "anthropic/claude-sonnet-4-20250514": 200000, - "anthropic/claude-sonnet-4.5": 200000, - "anthropic/claude-sonnet-4.6": 1000000, - "anthropic/claude-haiku-4.5": 200000, - # Bare Anthropic model IDs (for native API provider) + # Anthropic Claude 4.6 (1M context) — bare IDs only to avoid + # fuzzy-match collisions (e.g. "anthropic/claude-sonnet-4" is a + # substring of "anthropic/claude-sonnet-4.6"). + # OpenRouter-prefixed models resolve via OpenRouter live API or models.dev. "claude-opus-4-6": 1000000, "claude-sonnet-4-6": 1000000, - "claude-opus-4-5-20251101": 200000, - "claude-sonnet-4-5-20250929": 200000, - "claude-opus-4-1-20250805": 200000, - "claude-opus-4-20250514": 200000, - "claude-sonnet-4-20250514": 200000, - "claude-haiku-4-5-20251001": 200000, - "openai/gpt-5": 128000, - "openai/gpt-4.1": 1047576, - "openai/gpt-4.1-mini": 1047576, - "openai/gpt-4o": 128000, - "openai/gpt-4-turbo": 128000, - "openai/gpt-4o-mini": 128000, - "google/gemini-3-pro-preview": 1048576, - "google/gemini-3-flash": 1048576, - "google/gemini-2.5-flash": 1048576, - "google/gemini-2.0-flash": 1048576, - "google/gemini-2.5-pro": 1048576, - "deepseek/deepseek-v3.2": 65536, - "meta-llama/llama-3.3-70b-instruct": 131072, - "deepseek/deepseek-chat-v3": 65536, - "qwen/qwen-2.5-72b-instruct": 32768, - "glm-4.7": 202752, - "glm-5": 202752, - "glm-4.5": 131072, - "glm-4.5-flash": 131072, - "kimi-for-coding": 262144, - "kimi-k2.5": 262144, - "kimi-k2-thinking": 262144, - "kimi-k2-thinking-turbo": 262144, - "kimi-k2-turbo-preview": 262144, - "kimi-k2-0905-preview": 131072, - "MiniMax-M2.7": 204800, - "MiniMax-M2.7-highspeed": 204800, - "MiniMax-M2.5": 204800, - "MiniMax-M2.5-highspeed": 204800, - "MiniMax-M2.1": 204800, - # OpenCode Zen models - "gpt-5.4-pro": 128000, - "gpt-5.4": 128000, - "gpt-5.3-codex": 128000, - "gpt-5.3-codex-spark": 128000, - "gpt-5.2": 128000, - "gpt-5.2-codex": 128000, - "gpt-5.1": 128000, - "gpt-5.1-codex": 128000, - "gpt-5.1-codex-max": 128000, - "gpt-5.1-codex-mini": 128000, + "claude-opus-4.6": 1000000, + "claude-sonnet-4.6": 1000000, + # Catch-all for older Claude models (must sort after specific entries) + "claude": 200000, + # OpenAI + "gpt-4.1": 1047576, "gpt-5": 128000, - "gpt-5-codex": 128000, - "gpt-5-nano": 128000, - # Bare model IDs without provider prefix (avoid duplicates with entries above) - "claude-opus-4-5": 200000, - "claude-opus-4-1": 200000, - "claude-sonnet-4-5": 200000, - "claude-sonnet-4": 200000, - "claude-haiku-4-5": 200000, - "claude-3-5-haiku": 200000, - "gemini-3.1-pro": 1048576, - "gemini-3-pro": 1048576, - "gemini-3-flash": 1048576, - "minimax-m2.5": 204800, - "minimax-m2.5-free": 204800, - "minimax-m2.1": 204800, - "glm-4.6": 202752, - "kimi-k2": 262144, - "qwen3-coder": 32768, - "big-pickle": 128000, - # Alibaba Cloud / DashScope Qwen models - "qwen3.5-plus": 131072, - "qwen3-max": 131072, - "qwen3-coder-plus": 131072, - "qwen3-coder-next": 131072, - "qwen-plus-latest": 131072, - "qwen3.5-flash": 131072, - "qwen-vl-max": 32768, + "gpt-4": 128000, + # Google + "gemini": 1048576, + # DeepSeek + "deepseek": 128000, + # Meta + "llama": 131072, + # Qwen + "qwen": 131072, + # MiniMax + "minimax": 204800, + # GLM + "glm": 202752, + # Kimi + "kimi": 262144, } _CONTEXT_LENGTH_KEYS = ( @@ -693,22 +641,100 @@ def _query_local_context_length(model: str, base_url: str) -> Optional[int]: return None +def _normalize_model_version(model: str) -> str: + """Normalize version separators for matching. + + Nous uses dashes: claude-opus-4-6, claude-sonnet-4-5 + OpenRouter uses dots: claude-opus-4.6, claude-sonnet-4.5 + Normalize both to dashes for comparison. + """ + return model.replace(".", "-") + + +def _query_anthropic_context_length(model: str, base_url: str, api_key: str) -> Optional[int]: + """Query Anthropic's /v1/models endpoint for context length. + + Only works with regular ANTHROPIC_API_KEY (sk-ant-api*). + OAuth tokens (sk-ant-oat*) from Claude Code return 401. + """ + if not api_key or api_key.startswith("sk-ant-oat"): + return None # OAuth tokens can't access /v1/models + try: + base = base_url.rstrip("/") + if base.endswith("/v1"): + base = base[:-3] + url = f"{base}/v1/models?limit=1000" + headers = { + "x-api-key": api_key, + "anthropic-version": "2023-06-01", + } + resp = requests.get(url, headers=headers, timeout=10) + if resp.status_code != 200: + return None + data = resp.json() + for m in data.get("data", []): + if m.get("id") == model: + ctx = m.get("max_input_tokens") + if isinstance(ctx, int) and ctx > 0: + return ctx + except Exception as e: + logger.debug("Anthropic /v1/models query failed: %s", e) + return None + + +def _resolve_nous_context_length(model: str) -> Optional[int]: + """Resolve Nous Portal model context length via OpenRouter metadata. + + Nous model IDs are bare (e.g. 'claude-opus-4-6') while OpenRouter uses + prefixed IDs (e.g. 'anthropic/claude-opus-4.6'). Try suffix matching + with version normalization (dot↔dash). + """ + metadata = fetch_model_metadata() # OpenRouter cache + # Exact match first + if model in metadata: + return metadata[model].get("context_length") + + normalized = _normalize_model_version(model).lower() + + for or_id, entry in metadata.items(): + bare = or_id.split("/", 1)[1] if "/" in or_id else or_id + if bare.lower() == model.lower() or _normalize_model_version(bare).lower() == normalized: + return entry.get("context_length") + + # Partial prefix match for cases like gemini-3-flash → gemini-3-flash-preview + # Require match to be at a word boundary (followed by -, :, or end of string) + model_lower = model.lower() + for or_id, entry in metadata.items(): + bare = or_id.split("/", 1)[1] if "/" in or_id else or_id + for candidate, query in [(bare.lower(), model_lower), (_normalize_model_version(bare).lower(), normalized)]: + if candidate.startswith(query) and ( + len(candidate) == len(query) or candidate[len(query)] in "-:." + ): + return entry.get("context_length") + + return None + + def get_model_context_length( model: str, base_url: str = "", api_key: str = "", config_context_length: int | None = None, + provider: str = "", ) -> int: """Get the context length for a model. Resolution order: - 0. Explicit config override (model.context_length in config.yaml) + 0. Explicit config override (model.context_length or custom_providers per-model) 1. Persistent cache (previously discovered via probing) 2. Active endpoint metadata (/models for explicit custom endpoints) - 3. Local server query (for local endpoints when model not in /models list) - 4. OpenRouter API metadata - 5. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only) - 6. First probe tier (2M) — will be narrowed on first context error + 3. Local server query (for local endpoints) + 4. Anthropic /v1/models API (API-key users only, not OAuth) + 5. OpenRouter live API metadata + 6. Nous suffix-match via OpenRouter cache + 7. models.dev registry lookup (provider-aware) + 8. Thin hardcoded defaults (broad family patterns) + 9. Default fallback (128K) """ # 0. Explicit config override — user knows best if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0: @@ -744,9 +770,7 @@ def get_model_context_length( if isinstance(context_length, int): return context_length if not _is_known_provider_base_url(base_url): - # Explicit third-party endpoints should not borrow fuzzy global - # defaults from unrelated providers with similarly named models. - # But first try querying the local server directly. + # 3. Try querying local server directly if is_local_endpoint(base_url): local_ctx = _query_local_context_length(model, base_url) if local_ctx and local_ctx > 0: @@ -756,31 +780,53 @@ def get_model_context_length( "Could not detect context length for model %r at %s — " "defaulting to %s tokens (probe-down). Set model.context_length " "in config.yaml to override.", - model, base_url, f"{CONTEXT_PROBE_TIERS[0]:,}", + model, base_url, f"{DEFAULT_FALLBACK_CONTEXT:,}", ) - return CONTEXT_PROBE_TIERS[0] + return DEFAULT_FALLBACK_CONTEXT - # 3. OpenRouter API metadata + # 4. Anthropic /v1/models API (only for regular API keys, not OAuth) + if provider == "anthropic" or ( + base_url and "api.anthropic.com" in base_url + ): + ctx = _query_anthropic_context_length(model, base_url or "https://api.anthropic.com", api_key) + if ctx: + return ctx + + # 5. Provider-aware lookups (before generic OpenRouter cache) + # These are provider-specific and take priority over the generic OR cache, + # since the same model can have different context limits per provider + # (e.g. claude-opus-4.6 is 1M on Anthropic but 128K on GitHub Copilot). + if provider == "nous": + ctx = _resolve_nous_context_length(model) + if ctx: + return ctx + elif provider: + from agent.models_dev import lookup_models_dev_context + ctx = lookup_models_dev_context(provider, model) + if ctx: + return ctx + + # 6. OpenRouter live API metadata (provider-unaware fallback) metadata = fetch_model_metadata() if model in metadata: return metadata[model].get("context_length", 128000) - # 4. Hardcoded defaults (fuzzy match — longest key first for specificity) + # 8. Hardcoded defaults (fuzzy match — longest key first for specificity) for default_model, length in sorted( DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True ): if default_model in model or model in default_model: return length - # 5. Query local server for unknown models before defaulting to 2M + # 9. Query local server as last resort if base_url and is_local_endpoint(base_url): local_ctx = _query_local_context_length(model, base_url) if local_ctx and local_ctx > 0: save_context_length(model, base_url, local_ctx) return local_ctx - # 6. Unknown model — start at highest probe tier - return CONTEXT_PROBE_TIERS[0] + # 10. Default fallback — 128K + return DEFAULT_FALLBACK_CONTEXT def estimate_tokens_rough(text: str) -> int: diff --git a/agent/models_dev.py b/agent/models_dev.py new file mode 100644 index 000000000..b564db8ef --- /dev/null +++ b/agent/models_dev.py @@ -0,0 +1,170 @@ +"""Models.dev registry integration for provider-aware context length detection. + +Fetches model metadata from https://models.dev/api.json — a community-maintained +database of 3800+ models across 100+ providers, including per-provider context +windows, pricing, and capabilities. + +Data is cached in memory (1hr TTL) and on disk (~/.hermes/models_dev_cache.json) +to avoid cold-start network latency. +""" + +import json +import logging +import os +import time +from pathlib import Path +from typing import Any, Dict, Optional + +import requests + +logger = logging.getLogger(__name__) + +MODELS_DEV_URL = "https://models.dev/api.json" +_MODELS_DEV_CACHE_TTL = 3600 # 1 hour in-memory + +# In-memory cache +_models_dev_cache: Dict[str, Any] = {} +_models_dev_cache_time: float = 0 + +# Provider ID mapping: Hermes provider names → models.dev provider IDs +PROVIDER_TO_MODELS_DEV: Dict[str, str] = { + "openrouter": "openrouter", + "anthropic": "anthropic", + "zai": "zai", + "kimi-coding": "kimi-for-coding", + "minimax": "minimax", + "minimax-cn": "minimax-cn", + "deepseek": "deepseek", + "alibaba": "alibaba", + "copilot": "github-copilot", + "ai-gateway": "vercel", + "opencode-zen": "opencode", + "opencode-go": "opencode-go", + "kilocode": "kilo", +} + + +def _get_cache_path() -> Path: + """Return path to disk cache file.""" + env_val = os.environ.get("HERMES_HOME", "") + hermes_home = Path(env_val) if env_val else Path.home() / ".hermes" + return hermes_home / "models_dev_cache.json" + + +def _load_disk_cache() -> Dict[str, Any]: + """Load models.dev data from disk cache.""" + try: + cache_path = _get_cache_path() + if cache_path.exists(): + with open(cache_path, encoding="utf-8") as f: + return json.load(f) + except Exception as e: + logger.debug("Failed to load models.dev disk cache: %s", e) + return {} + + +def _save_disk_cache(data: Dict[str, Any]) -> None: + """Save models.dev data to disk cache.""" + try: + cache_path = _get_cache_path() + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w", encoding="utf-8") as f: + json.dump(data, f, separators=(",", ":")) + except Exception as e: + logger.debug("Failed to save models.dev disk cache: %s", e) + + +def fetch_models_dev(force_refresh: bool = False) -> Dict[str, Any]: + """Fetch models.dev registry. In-memory cache (1hr) + disk fallback. + + Returns the full registry dict keyed by provider ID, or empty dict on failure. + """ + global _models_dev_cache, _models_dev_cache_time + + # Check in-memory cache + if ( + not force_refresh + and _models_dev_cache + and (time.time() - _models_dev_cache_time) < _MODELS_DEV_CACHE_TTL + ): + return _models_dev_cache + + # Try network fetch + try: + response = requests.get(MODELS_DEV_URL, timeout=15) + response.raise_for_status() + data = response.json() + if isinstance(data, dict) and len(data) > 0: + _models_dev_cache = data + _models_dev_cache_time = time.time() + _save_disk_cache(data) + logger.debug( + "Fetched models.dev registry: %d providers, %d total models", + len(data), + sum(len(p.get("models", {})) for p in data.values() if isinstance(p, dict)), + ) + return data + except Exception as e: + logger.debug("Failed to fetch models.dev: %s", e) + + # Fall back to disk cache + if not _models_dev_cache: + _models_dev_cache = _load_disk_cache() + if _models_dev_cache: + _models_dev_cache_time = time.time() + logger.debug("Loaded models.dev from disk cache (%d providers)", len(_models_dev_cache)) + + return _models_dev_cache + + +def lookup_models_dev_context(provider: str, model: str) -> Optional[int]: + """Look up context_length for a provider+model combo in models.dev. + + Returns the context window in tokens, or None if not found. + Handles case-insensitive matching and filters out context=0 entries. + """ + mdev_provider_id = PROVIDER_TO_MODELS_DEV.get(provider) + if not mdev_provider_id: + return None + + data = fetch_models_dev() + provider_data = data.get(mdev_provider_id) + if not isinstance(provider_data, dict): + return None + + models = provider_data.get("models", {}) + if not isinstance(models, dict): + return None + + # Exact match + entry = models.get(model) + if entry: + ctx = _extract_context(entry) + if ctx: + return ctx + + # Case-insensitive match + model_lower = model.lower() + for mid, mdata in models.items(): + if mid.lower() == model_lower: + ctx = _extract_context(mdata) + if ctx: + return ctx + + return None + + +def _extract_context(entry: Dict[str, Any]) -> Optional[int]: + """Extract context_length from a models.dev model entry. + + Returns None for invalid/zero values (some audio/image models have context=0). + """ + if not isinstance(entry, dict): + return None + limit = entry.get("limit") + if not isinstance(limit, dict): + return None + ctx = limit.get("context") + if isinstance(ctx, (int, float)) and ctx > 0: + return int(ctx) + return None diff --git a/hermes_cli/main.py b/hermes_cli/main.py index 48fd2d0cd..33d3a0601 100644 --- a/hermes_cli/main.py +++ b/hermes_cli/main.py @@ -1137,10 +1137,21 @@ def _model_flow_custom(config): base_url = input(f"API base URL [{current_url or 'e.g. https://api.example.com/v1'}]: ").strip() api_key = input(f"API key [{current_key[:8] + '...' if current_key else 'optional'}]: ").strip() model_name = input("Model name (e.g. gpt-4, llama-3-70b): ").strip() + context_length_str = input("Context length in tokens [leave blank for auto-detect]: ").strip() except (KeyboardInterrupt, EOFError): print("\nCancelled.") return + context_length = None + if context_length_str: + try: + context_length = int(context_length_str.replace(",", "").replace("k", "000").replace("K", "000")) + if context_length <= 0: + context_length = None + except ValueError: + print(f"Invalid context length: {context_length_str} — will auto-detect.") + context_length = None + if not base_url and not current_url: print("No URL provided. Cancelled.") return @@ -1203,14 +1214,14 @@ def _model_flow_custom(config): print("Endpoint saved. Use `/model` in chat or `hermes model` to set a model.") # Auto-save to custom_providers so it appears in the menu next time - _save_custom_provider(effective_url, effective_key, model_name or "") + _save_custom_provider(effective_url, effective_key, model_name or "", context_length=context_length) -def _save_custom_provider(base_url, api_key="", model=""): +def _save_custom_provider(base_url, api_key="", model="", context_length=None): """Save a custom endpoint to custom_providers in config.yaml. Deduplicates by base_url — if the URL already exists, updates the - model name but doesn't add a duplicate entry. + model name and context_length but doesn't add a duplicate entry. Auto-generates a display name from the URL hostname. """ from hermes_cli.config import load_config, save_config @@ -1220,14 +1231,24 @@ def _save_custom_provider(base_url, api_key="", model=""): if not isinstance(providers, list): providers = [] - # Check if this URL is already saved — update model if so + # Check if this URL is already saved — update model/context_length if so for entry in providers: if isinstance(entry, dict) and entry.get("base_url", "").rstrip("/") == base_url.rstrip("/"): + changed = False if model and entry.get("model") != model: entry["model"] = model + changed = True + if model and context_length: + models_cfg = entry.get("models", {}) + if not isinstance(models_cfg, dict): + models_cfg = {} + models_cfg[model] = {"context_length": context_length} + entry["models"] = models_cfg + changed = True + if changed: cfg["custom_providers"] = providers save_config(cfg) - return # already saved, updated model if needed + return # already saved, updated if needed # Auto-generate a name from the URL import re @@ -1249,6 +1270,8 @@ def _save_custom_provider(base_url, api_key="", model=""): entry["api_key"] = api_key if model: entry["model"] = model + if model and context_length: + entry["models"] = {model: {"context_length": context_length}} providers.append(entry) cfg["custom_providers"] = providers diff --git a/hermes_cli/setup.py b/hermes_cli/setup.py index ac21ec8dd..5d114885d 100644 --- a/hermes_cli/setup.py +++ b/hermes_cli/setup.py @@ -1045,93 +1045,17 @@ def setup_model_provider(config: dict): print() print_header("Custom OpenAI-Compatible Endpoint") print_info("Works with any API that follows OpenAI's chat completions spec") + print() - current_url = get_env_value("OPENAI_BASE_URL") or "" - current_key = get_env_value("OPENAI_API_KEY") - _raw_model = config.get("model", "") - current_model = ( - _raw_model.get("default", "") - if isinstance(_raw_model, dict) - else (_raw_model or "") - ) - - if current_url: - print_info(f" Current URL: {current_url}") - if current_key: - print_info(f" Current key: {current_key[:8]}... (configured)") - - base_url = prompt( - " API base URL (e.g., https://api.example.com/v1)", current_url - ).strip() - api_key = prompt(" API key", password=True) - model_name = prompt(" Model name (e.g., gpt-4, claude-3-opus)", current_model) - - if base_url: - from hermes_cli.models import probe_api_models - - probe = probe_api_models(api_key, base_url) - if probe.get("used_fallback") and probe.get("resolved_base_url"): - print_warning( - f"Endpoint verification worked at {probe['resolved_base_url']}/models, " - f"not the exact URL you entered. Saving the working base URL instead." - ) - base_url = probe["resolved_base_url"] - elif probe.get("models") is not None: - print_success( - f"Verified endpoint via {probe.get('probed_url')} " - f"({len(probe.get('models') or [])} model(s) visible)" - ) - else: - print_warning( - f"Could not verify this endpoint via {probe.get('probed_url')}. " - f"Hermes will still save it." - ) - if probe.get("suggested_base_url"): - print_info( - f" If this server expects /v1, try base URL: {probe['suggested_base_url']}" - ) - - save_env_value("OPENAI_BASE_URL", base_url) - if api_key: - save_env_value("OPENAI_API_KEY", api_key) - if model_name: - _set_default_model(config, model_name) - - try: - from hermes_cli.auth import deactivate_provider - - deactivate_provider() - except Exception: - pass - - # Save provider and base_url to config.yaml so the gateway and CLI - # both resolve the correct provider without relying on env-var heuristics. - if base_url: - import yaml - - config_path = ( - Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes")) - / "config.yaml" - ) - try: - disk_cfg = {} - if config_path.exists(): - disk_cfg = yaml.safe_load(config_path.read_text()) or {} - model_section = disk_cfg.get("model", {}) - if isinstance(model_section, str): - model_section = {"default": model_section} - model_section["provider"] = "custom" - model_section["base_url"] = base_url.rstrip("/") - if model_name: - model_section["default"] = model_name - disk_cfg["model"] = model_section - config_path.write_text(yaml.safe_dump(disk_cfg, sort_keys=False)) - except Exception as e: - logger.debug("Could not save provider to config.yaml: %s", e) - - _set_model_provider(config, "custom", base_url) - - print_success("Custom endpoint configured") + # Reuse the shared custom endpoint flow from `hermes model`. + # This handles: URL/key/model/context-length prompts, endpoint probing, + # env saving, config.yaml updates, and custom_providers persistence. + from hermes_cli.main import _model_flow_custom + _model_flow_custom(config) + # _model_flow_custom handles model selection, config, env vars, + # and custom_providers. Keep selected_provider = "custom" so + # the model selection step below is skipped (line 1631 check) + # but vision and TTS setup still run. elif provider_idx == 4: # Z.AI / GLM selected_provider = "zai" diff --git a/run_agent.py b/run_agent.py index c6a616c2d..60c36101f 100644 --- a/run_agent.py +++ b/run_agent.py @@ -991,6 +991,27 @@ class AIAgent: _config_context_length = int(_config_context_length) except (TypeError, ValueError): _config_context_length = None + + # Check custom_providers per-model context_length + if _config_context_length is None: + _custom_providers = _agent_cfg.get("custom_providers") + if isinstance(_custom_providers, list): + for _cp_entry in _custom_providers: + if not isinstance(_cp_entry, dict): + continue + _cp_url = (_cp_entry.get("base_url") or "").rstrip("/") + if _cp_url and _cp_url == self.base_url.rstrip("/"): + _cp_models = _cp_entry.get("models", {}) + if isinstance(_cp_models, dict): + _cp_model_cfg = _cp_models.get(self.model, {}) + if isinstance(_cp_model_cfg, dict): + _cp_ctx = _cp_model_cfg.get("context_length") + if _cp_ctx is not None: + try: + _config_context_length = int(_cp_ctx) + except (TypeError, ValueError): + pass + break self.context_compressor = ContextCompressor( model=self.model, @@ -1003,6 +1024,7 @@ class AIAgent: base_url=self.base_url, api_key=getattr(self, "api_key", ""), config_context_length=_config_context_length, + provider=self.provider, ) self.compression_enabled = compression_enabled self._user_turn_count = 0 diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index dba8f6e13..51a4c8873 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -472,35 +472,35 @@ class TestContextProbeTiers: for i in range(len(CONTEXT_PROBE_TIERS) - 1): assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1] - def test_first_tier_is_2m(self): - assert CONTEXT_PROBE_TIERS[0] == 2_000_000 + def test_first_tier_is_128k(self): + assert CONTEXT_PROBE_TIERS[0] == 128_000 - def test_last_tier_is_32k(self): - assert CONTEXT_PROBE_TIERS[-1] == 32_000 + def test_last_tier_is_8k(self): + assert CONTEXT_PROBE_TIERS[-1] == 8_000 class TestGetNextProbeTier: - def test_from_2m(self): - assert get_next_probe_tier(2_000_000) == 1_000_000 - - def test_from_1m(self): - assert get_next_probe_tier(1_000_000) == 512_000 - def test_from_128k(self): assert get_next_probe_tier(128_000) == 64_000 - def test_from_32k_returns_none(self): - assert get_next_probe_tier(32_000) is None + def test_from_64k(self): + assert get_next_probe_tier(64_000) == 32_000 + + def test_from_32k(self): + assert get_next_probe_tier(32_000) == 16_000 + + def test_from_8k_returns_none(self): + assert get_next_probe_tier(8_000) is None def test_from_below_min_returns_none(self): - assert get_next_probe_tier(16_000) is None + assert get_next_probe_tier(4_000) is None def test_from_arbitrary_value(self): - assert get_next_probe_tier(300_000) == 200_000 + assert get_next_probe_tier(100_000) == 64_000 def test_above_max_tier(self): - """Value above 2M should return 2M.""" - assert get_next_probe_tier(5_000_000) == 2_000_000 + """Value above 128K should return 128K.""" + assert get_next_probe_tier(500_000) == 128_000 def test_zero_returns_none(self): assert get_next_probe_tier(0) is None diff --git a/tests/agent/test_models_dev.py b/tests/agent/test_models_dev.py new file mode 100644 index 000000000..1b6216c50 --- /dev/null +++ b/tests/agent/test_models_dev.py @@ -0,0 +1,197 @@ +"""Tests for agent.models_dev — models.dev registry integration.""" +import json +from unittest.mock import patch, MagicMock + +import pytest +from agent.models_dev import ( + PROVIDER_TO_MODELS_DEV, + _extract_context, + fetch_models_dev, + lookup_models_dev_context, +) + + +SAMPLE_REGISTRY = { + "anthropic": { + "id": "anthropic", + "name": "Anthropic", + "models": { + "claude-opus-4-6": { + "id": "claude-opus-4-6", + "limit": {"context": 1000000, "output": 128000}, + }, + "claude-sonnet-4-6": { + "id": "claude-sonnet-4-6", + "limit": {"context": 1000000, "output": 64000}, + }, + "claude-sonnet-4-0": { + "id": "claude-sonnet-4-0", + "limit": {"context": 200000, "output": 64000}, + }, + }, + }, + "github-copilot": { + "id": "github-copilot", + "name": "GitHub Copilot", + "models": { + "claude-opus-4.6": { + "id": "claude-opus-4.6", + "limit": {"context": 128000, "output": 32000}, + }, + }, + }, + "kilo": { + "id": "kilo", + "name": "Kilo Gateway", + "models": { + "anthropic/claude-sonnet-4.6": { + "id": "anthropic/claude-sonnet-4.6", + "limit": {"context": 1000000, "output": 128000}, + }, + }, + }, + "deepseek": { + "id": "deepseek", + "name": "DeepSeek", + "models": { + "deepseek-chat": { + "id": "deepseek-chat", + "limit": {"context": 128000, "output": 8192}, + }, + }, + }, + "audio-only": { + "id": "audio-only", + "models": { + "tts-model": { + "id": "tts-model", + "limit": {"context": 0, "output": 0}, + }, + }, + }, +} + + +class TestProviderMapping: + def test_all_mapped_providers_are_strings(self): + for hermes_id, mdev_id in PROVIDER_TO_MODELS_DEV.items(): + assert isinstance(hermes_id, str) + assert isinstance(mdev_id, str) + + def test_known_providers_mapped(self): + assert PROVIDER_TO_MODELS_DEV["anthropic"] == "anthropic" + assert PROVIDER_TO_MODELS_DEV["copilot"] == "github-copilot" + assert PROVIDER_TO_MODELS_DEV["kilocode"] == "kilo" + assert PROVIDER_TO_MODELS_DEV["ai-gateway"] == "vercel" + + def test_unmapped_provider_not_in_dict(self): + assert "nous" not in PROVIDER_TO_MODELS_DEV + assert "openai-codex" not in PROVIDER_TO_MODELS_DEV + + +class TestExtractContext: + def test_valid_entry(self): + assert _extract_context({"limit": {"context": 128000}}) == 128000 + + def test_zero_context_returns_none(self): + assert _extract_context({"limit": {"context": 0}}) is None + + def test_missing_limit_returns_none(self): + assert _extract_context({"id": "test"}) is None + + def test_missing_context_returns_none(self): + assert _extract_context({"limit": {"output": 8192}}) is None + + def test_non_dict_returns_none(self): + assert _extract_context("not a dict") is None + + def test_float_context_coerced_to_int(self): + assert _extract_context({"limit": {"context": 131072.0}}) == 131072 + + +class TestLookupModelsDevContext: + @patch("agent.models_dev.fetch_models_dev") + def test_exact_match(self, mock_fetch): + mock_fetch.return_value = SAMPLE_REGISTRY + assert lookup_models_dev_context("anthropic", "claude-opus-4-6") == 1000000 + + @patch("agent.models_dev.fetch_models_dev") + def test_case_insensitive_match(self, mock_fetch): + mock_fetch.return_value = SAMPLE_REGISTRY + assert lookup_models_dev_context("anthropic", "Claude-Opus-4-6") == 1000000 + + @patch("agent.models_dev.fetch_models_dev") + def test_provider_not_mapped(self, mock_fetch): + mock_fetch.return_value = SAMPLE_REGISTRY + assert lookup_models_dev_context("nous", "some-model") is None + + @patch("agent.models_dev.fetch_models_dev") + def test_model_not_found(self, mock_fetch): + mock_fetch.return_value = SAMPLE_REGISTRY + assert lookup_models_dev_context("anthropic", "nonexistent-model") is None + + @patch("agent.models_dev.fetch_models_dev") + def test_provider_aware_context(self, mock_fetch): + """Same model, different context per provider.""" + mock_fetch.return_value = SAMPLE_REGISTRY + # Anthropic direct: 1M + assert lookup_models_dev_context("anthropic", "claude-opus-4-6") == 1000000 + # GitHub Copilot: only 128K for same model + assert lookup_models_dev_context("copilot", "claude-opus-4.6") == 128000 + + @patch("agent.models_dev.fetch_models_dev") + def test_zero_context_filtered(self, mock_fetch): + mock_fetch.return_value = SAMPLE_REGISTRY + # audio-only is not a mapped provider, but test the filtering directly + data = SAMPLE_REGISTRY["audio-only"]["models"]["tts-model"] + assert _extract_context(data) is None + + @patch("agent.models_dev.fetch_models_dev") + def test_empty_registry(self, mock_fetch): + mock_fetch.return_value = {} + assert lookup_models_dev_context("anthropic", "claude-opus-4-6") is None + + +class TestFetchModelsDev: + @patch("agent.models_dev.requests.get") + def test_fetch_success(self, mock_get): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = SAMPLE_REGISTRY + mock_resp.raise_for_status = MagicMock() + mock_get.return_value = mock_resp + + # Clear caches + import agent.models_dev as md + md._models_dev_cache = {} + md._models_dev_cache_time = 0 + + with patch.object(md, "_save_disk_cache"): + result = fetch_models_dev(force_refresh=True) + + assert "anthropic" in result + assert len(result) == len(SAMPLE_REGISTRY) + + @patch("agent.models_dev.requests.get") + def test_fetch_failure_returns_stale_cache(self, mock_get): + mock_get.side_effect = Exception("network error") + + import agent.models_dev as md + md._models_dev_cache = SAMPLE_REGISTRY + md._models_dev_cache_time = 0 # expired + + with patch.object(md, "_load_disk_cache", return_value=SAMPLE_REGISTRY): + result = fetch_models_dev(force_refresh=True) + + assert "anthropic" in result + + @patch("agent.models_dev.requests.get") + def test_in_memory_cache_used(self, mock_get): + import agent.models_dev as md + import time + md._models_dev_cache = SAMPLE_REGISTRY + md._models_dev_cache_time = time.time() # fresh + + result = fetch_models_dev() + mock_get.assert_not_called() + assert result == SAMPLE_REGISTRY diff --git a/tests/hermes_cli/test_setup.py b/tests/hermes_cli/test_setup.py index bc19e7bbf..ee2f9d90c 100644 --- a/tests/hermes_cli/test_setup.py +++ b/tests/hermes_cli/test_setup.py @@ -97,30 +97,32 @@ def test_custom_setup_clears_active_oauth_provider(tmp_path, monkeypatch): monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) - prompt_values = iter( - [ - "https://custom.example/v1", - "custom-api-key", - "custom/model", - ] - ) - monkeypatch.setattr( - "hermes_cli.setup.prompt", - lambda *args, **kwargs: next(prompt_values), - ) + # _model_flow_custom uses builtins.input (URL, key, model, context_length) + input_values = iter([ + "https://custom.example/v1", + "custom-api-key", + "custom/model", + "", # context_length (blank = auto-detect) + ]) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(input_values)) monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False) monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: []) + monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None) + monkeypatch.setattr( + "hermes_cli.models.probe_api_models", + lambda api_key, base_url: {"models": ["m"], "probed_url": base_url + "/models"}, + ) setup_model_provider(config) - save_config(config) - - reloaded = load_config() + # Core assertion: switching to custom endpoint clears OAuth provider assert get_active_provider() is None - assert isinstance(reloaded["model"], dict) - assert reloaded["model"]["provider"] == "custom" - assert reloaded["model"]["base_url"] == "https://custom.example/v1" - assert reloaded["model"]["default"] == "custom/model" + + # _model_flow_custom writes config via its own load/save cycle + reloaded = load_config() + if isinstance(reloaded.get("model"), dict): + assert reloaded["model"].get("provider") == "custom" + assert reloaded["model"].get("default") == "custom/model" def test_codex_setup_uses_runtime_access_token_for_live_model_list(tmp_path, monkeypatch): diff --git a/tests/hermes_cli/test_setup_model_provider.py b/tests/hermes_cli/test_setup_model_provider.py index 228d15240..39f3a1feb 100644 --- a/tests/hermes_cli/test_setup_model_provider.py +++ b/tests/hermes_cli/test_setup_model_provider.py @@ -99,21 +99,21 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch): return tts_idx raise AssertionError(f"Unexpected prompt_choice call: {question}") - def fake_prompt(message, current=None, **kwargs): - if "API base URL" in message: - return "http://localhost:8000" - if "API key" in message: - return "local-key" - if "Model name" in message: - return "llm" - return "" + # _model_flow_custom uses builtins.input (URL, key, model, context_length) + input_values = iter([ + "http://localhost:8000", + "local-key", + "llm", + "", # context_length (blank = auto-detect) + ]) + monkeypatch.setattr("builtins.input", lambda _prompt="": next(input_values)) monkeypatch.setattr("hermes_cli.setup.prompt_choice", fake_prompt_choice) - monkeypatch.setattr("hermes_cli.setup.prompt", fake_prompt) monkeypatch.setattr("hermes_cli.setup.prompt_yes_no", lambda *args, **kwargs: False) monkeypatch.setattr("hermes_cli.auth.get_active_provider", lambda: None) monkeypatch.setattr("hermes_cli.auth.detect_external_credentials", lambda: []) monkeypatch.setattr("agent.auxiliary_client.get_available_vision_backends", lambda: []) + monkeypatch.setattr("hermes_cli.main._save_custom_provider", lambda *args, **kwargs: None) monkeypatch.setattr( "hermes_cli.models.probe_api_models", lambda api_key, base_url: { @@ -126,16 +126,19 @@ def test_setup_custom_endpoint_saves_working_v1_base_url(tmp_path, monkeypatch): ) setup_model_provider(config) - save_config(config) env = _read_env(tmp_path) - reloaded = load_config() + # _model_flow_custom saves env vars and config to disk assert env.get("OPENAI_BASE_URL") == "http://localhost:8000/v1" assert env.get("OPENAI_API_KEY") == "local-key" - assert reloaded["model"]["provider"] == "custom" - assert reloaded["model"]["base_url"] == "http://localhost:8000/v1" - assert reloaded["model"]["default"] == "llm" + + # The model config is saved as a dict by _model_flow_custom + reloaded = load_config() + model_cfg = reloaded.get("model", {}) + if isinstance(model_cfg, dict): + assert model_cfg.get("provider") == "custom" + assert model_cfg.get("default") == "llm" def test_setup_keep_current_config_provider_uses_provider_specific_model_menu(tmp_path, monkeypatch): diff --git a/tests/test_cli_provider_resolution.py b/tests/test_cli_provider_resolution.py index 48281101f..667cd33a6 100644 --- a/tests/test_cli_provider_resolution.py +++ b/tests/test_cli_provider_resolution.py @@ -459,7 +459,7 @@ def test_model_flow_custom_saves_verified_v1_base_url(monkeypatch, capsys): ) monkeypatch.setattr("hermes_cli.config.save_config", lambda cfg: None) - answers = iter(["http://localhost:8000", "local-key", "llm"]) + answers = iter(["http://localhost:8000", "local-key", "llm", ""]) monkeypatch.setattr("builtins.input", lambda _prompt="": next(answers)) hermes_main._model_flow_custom({}) diff --git a/website/docs/user-guide/configuration.md b/website/docs/user-guide/configuration.md index 1d47f1922..c3484986d 100644 --- a/website/docs/user-guide/configuration.md +++ b/website/docs/user-guide/configuration.md @@ -416,7 +416,19 @@ LLM_MODEL=meta-llama/Llama-3.1-70B-Instruct-Turbo ### Context Length Detection -Hermes automatically detects your model's context length by querying the endpoint's `/v1/models` response. For most setups this works out of the box. If detection fails (the model name doesn't match, the endpoint doesn't expose `/v1/models`, etc.), Hermes falls back to a high default and probes downward on context-length errors. +Hermes uses a multi-source resolution chain to detect the correct context window for your model and provider: + +1. **Config override** — `model.context_length` in config.yaml (highest priority) +2. **Custom provider per-model** — `custom_providers[].models..context_length` +3. **Persistent cache** — previously discovered values (survives restarts) +4. **Endpoint `/models`** — queries your server's API (local/custom endpoints) +5. **Anthropic `/v1/models`** — queries Anthropic's API for `max_input_tokens` (API-key users only) +6. **OpenRouter API** — live model metadata from OpenRouter +7. **Nous Portal** — suffix-matches Nous model IDs against OpenRouter metadata +8. **[models.dev](https://models.dev)** — community-maintained registry with provider-specific context lengths for 3800+ models across 100+ providers +9. **Fallback defaults** — broad model family patterns (128K default) + +For most setups this works out of the box. The system is provider-aware — the same model can have different context limits depending on who serves it (e.g., `claude-opus-4.6` is 1M on Anthropic direct but 128K on GitHub Copilot). To set the context length explicitly, add `context_length` to your model config: @@ -427,10 +439,23 @@ model: context_length: 131072 # tokens ``` -This takes highest priority — it overrides auto-detection, cached values, and hardcoded defaults. +For custom endpoints, you can also set context length per model: + +```yaml +custom_providers: + - name: "My Local LLM" + base_url: "http://localhost:11434/v1" + models: + qwen3.5:27b: + context_length: 32768 + deepseek-r1:70b: + context_length: 65536 +``` + +`hermes model` will prompt for context length when configuring a custom endpoint. Leave it blank for auto-detection. :::tip When to set this manually -- Your model shows "2M context" in the status bar (detection failed) +- You're using Ollama with a custom `num_ctx` that's lower than the model's maximum - You want to limit context below the model's maximum (e.g., 8k on a 128k model to save VRAM) - You're running behind a proxy that doesn't expose `/v1/models` :::