The context length resolver was querying the /models endpoint for known providers like GitHub Copilot, which returns a provider-imposed limit (128k) instead of the model's actual context window (400k for gpt-5.4). Since this check happened before the models.dev lookup, the wrong value won every time.

Fix:
- Add api.githubcopilot.com and models.github.ai to _URL_TO_PROVIDER
- Skip the endpoint metadata probe for known providers — their /models data is unreliable for context length; models.dev has the correct per-provider values.

Reported by danny [DUMB] — gpt-5.4 via Copilot was resolving to 128k instead of the correct 400k from models.dev.
"""Model metadata, context lengths, and token estimation utilities.
|
|
|
|
Pure utility functions with no AIAgent dependency. Used by ContextCompressor
|
|
and run_agent.py for pre-flight context checks.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
import yaml
|
|
|
|
from hermes_constants import OPENROUTER_MODELS_URL
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Provider names that can appear as a "provider:" prefix before a model ID.
|
|
# Only these are stripped — Ollama-style "model:tag" colons (e.g. "qwen3.5:27b")
|
|
# are preserved so the full model name reaches cache lookups and server queries.
|
|
_PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
|
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
|
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
|
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
|
"custom", "local",
|
|
# Common aliases
|
|
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
|
|
"github-models", "kimi", "moonshot", "claude", "deep-seek",
|
|
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
|
})
|
|
|
|
|
|
_OLLAMA_TAG_PATTERN = re.compile(
|
|
r"^(\d+\.?\d*b|latest|stable|q\d|fp?\d|instruct|chat|coder|vision|text)",
|
|
re.IGNORECASE,
|
|
)
|
|
|
|
|
|
def _strip_provider_prefix(model: str) -> str:
|
|
"""Strip a recognised provider prefix from a model string.
|
|
|
|
``"local:my-model"`` → ``"my-model"``
|
|
``"qwen3.5:27b"`` → ``"qwen3.5:27b"`` (unchanged — not a provider prefix)
|
|
``"qwen:0.5b"`` → ``"qwen:0.5b"`` (unchanged — Ollama model:tag)
|
|
``"deepseek:latest"``→ ``"deepseek:latest"``(unchanged — Ollama model:tag)
|
|
"""
|
|
if ":" not in model or model.startswith("http"):
|
|
return model
|
|
prefix, suffix = model.split(":", 1)
|
|
prefix_lower = prefix.strip().lower()
|
|
if prefix_lower in _PROVIDER_PREFIXES:
|
|
# Don't strip if suffix looks like an Ollama tag (e.g. "7b", "latest", "q4_0")
|
|
if _OLLAMA_TAG_PATTERN.match(suffix.strip()):
|
|
return model
|
|
return suffix
|
|
return model
|
|
|
|
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
|
_model_metadata_cache_time: float = 0
|
|
_MODEL_CACHE_TTL = 3600
|
|
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
|
|
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
|
|
_ENDPOINT_MODEL_CACHE_TTL = 300
|
|
|
|
# Descending tiers for context length probing when the model is unknown.
|
|
# We start at 128K (a safe default for most modern models) and step down
|
|
# on context-length errors until one works.
|
|
CONTEXT_PROBE_TIERS = [
|
|
128_000,
|
|
64_000,
|
|
32_000,
|
|
16_000,
|
|
8_000,
|
|
]
|
|
|
|
# Default context length when no detection method succeeds.
|
|
DEFAULT_FALLBACK_CONTEXT = CONTEXT_PROBE_TIERS[0]
|
|
|
|
# Thin fallback defaults — only broad model family patterns.
# These fire only when provider is unknown AND models.dev/OpenRouter/Anthropic
# all miss. Replaced the previous 80+ entry dict.
# For provider-specific context lengths, models.dev is the primary source.
DEFAULT_CONTEXT_LENGTHS = {
    # Anthropic Claude 4.6 (1M context) — bare IDs only to avoid
    # fuzzy-match collisions (e.g. "anthropic/claude-sonnet-4" is a
    # substring of "anthropic/claude-sonnet-4.6").
    # OpenRouter-prefixed models resolve via OpenRouter live API or models.dev.
    "claude-opus-4-6": 1000000,
    "claude-sonnet-4-6": 1000000,
    "claude-opus-4.6": 1000000,
    "claude-sonnet-4.6": 1000000,
    # Catch-all for older Claude models (must sort after specific entries)
    "claude": 200000,
    # OpenAI
    "gpt-4.1": 1047576,
    "gpt-5": 128000,
    "gpt-4": 128000,
    # Google
    "gemini": 1048576,
    # DeepSeek
    "deepseek": 128000,
    # Meta
    "llama": 131072,
    # Qwen
    "qwen": 131072,
    # MiniMax
    "minimax": 204800,
    # GLM
    "glm": 202752,
    # Kimi
    "kimi": 262144,
}

_CONTEXT_LENGTH_KEYS = (
    "context_length",
    "context_window",
    "max_context_length",
    "max_position_embeddings",
    "max_model_len",
    "max_input_tokens",
    "max_sequence_length",
    "max_seq_len",
    "n_ctx_train",
    "n_ctx",
)

_MAX_COMPLETION_KEYS = (
    "max_completion_tokens",
    "max_output_tokens",
    "max_tokens",
)

# Local server hostnames / address patterns
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")


def _normalize_base_url(base_url: str) -> str:
    return (base_url or "").strip().rstrip("/")


def _is_openrouter_base_url(base_url: str) -> bool:
    return "openrouter.ai" in _normalize_base_url(base_url).lower()


def _is_custom_endpoint(base_url: str) -> bool:
    normalized = _normalize_base_url(base_url)
    return bool(normalized) and not _is_openrouter_base_url(normalized)


_URL_TO_PROVIDER: Dict[str, str] = {
    "api.openai.com": "openai",
    "chatgpt.com": "openai",
    "api.anthropic.com": "anthropic",
    "api.z.ai": "zai",
    "api.moonshot.ai": "kimi-coding",
    "api.kimi.com": "kimi-coding",
    "api.minimax": "minimax",
    "dashscope.aliyuncs.com": "alibaba",
    "dashscope-intl.aliyuncs.com": "alibaba",
    "openrouter.ai": "openrouter",
    "inference-api.nousresearch.com": "nous",
    "api.deepseek.com": "deepseek",
    "api.githubcopilot.com": "copilot",
    "models.github.ai": "copilot",
}


def _infer_provider_from_url(base_url: str) -> Optional[str]:
    """Infer the models.dev provider name from a base URL.

    This allows context length resolution via models.dev for custom endpoints
    like DashScope (Alibaba), Z.AI, Kimi, etc. without requiring the user to
    explicitly set the provider name in config.
    """
    normalized = _normalize_base_url(base_url)
    if not normalized:
        return None
    parsed = urlparse(normalized if "://" in normalized else f"https://{normalized}")
    host = parsed.netloc.lower() or parsed.path.lower()
    for url_part, provider in _URL_TO_PROVIDER.items():
        if url_part in host:
            return provider
    return None


def _is_known_provider_base_url(base_url: str) -> bool:
    return _infer_provider_from_url(base_url) is not None

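# Illustrative lookups, derived from the _URL_TO_PROVIDER table above (shown here
# because the Copilot entries are what this change adds; the values follow from
# the substring match in _infer_provider_from_url):
#   _infer_provider_from_url("https://api.githubcopilot.com") -> "copilot"
#   _infer_provider_from_url("https://models.github.ai")      -> "copilot"
#   _infer_provider_from_url("http://localhost:11434/v1")     -> None (no known provider)
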
def is_local_endpoint(base_url: str) -> bool:
    """Return True if base_url points to a local machine (localhost / RFC-1918 / WSL)."""
    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False
    url = normalized if "://" in normalized else f"http://{normalized}"
    try:
        parsed = urlparse(url)
        host = parsed.hostname or ""
    except Exception:
        return False
    if host in _LOCAL_HOSTS:
        return True
    # RFC-1918 private ranges and link-local
    import ipaddress
    try:
        addr = ipaddress.ip_address(host)
        return addr.is_private or addr.is_loopback or addr.is_link_local
    except ValueError:
        pass
    # Bare IP that looks like a private range (e.g. 172.26.x.x for WSL)
    parts = host.split(".")
    if len(parts) == 4:
        try:
            first, second = int(parts[0]), int(parts[1])
            if first == 10:
                return True
            if first == 172 and 16 <= second <= 31:
                return True
            if first == 192 and second == 168:
                return True
        except ValueError:
            pass
    return False


def detect_local_server_type(base_url: str) -> Optional[str]:
    """Detect which local server is running at base_url by probing known endpoints.

    Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
    """
    import httpx

    normalized = _normalize_base_url(base_url)
    server_url = normalized
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]

    try:
        with httpx.Client(timeout=2.0) as client:
            # LM Studio exposes /api/v1/models — check first (most specific)
            try:
                r = client.get(f"{server_url}/api/v1/models")
                if r.status_code == 200:
                    return "lm-studio"
            except Exception:
                pass
            # Ollama exposes /api/tags and responds with {"models": [...]}
            # LM Studio returns {"error": "Unexpected endpoint"} with status 200
            # on this path, so we must verify the response contains "models".
            try:
                r = client.get(f"{server_url}/api/tags")
                if r.status_code == 200:
                    try:
                        data = r.json()
                        if "models" in data:
                            return "ollama"
                    except Exception:
                        pass
            except Exception:
                pass
            # llama.cpp exposes /v1/props (older builds used /props without the /v1 prefix)
            try:
                r = client.get(f"{server_url}/v1/props")
                if r.status_code != 200:
                    r = client.get(f"{server_url}/props")  # fallback for older builds
                if r.status_code == 200 and "default_generation_settings" in r.text:
                    return "llamacpp"
            except Exception:
                pass
            # vLLM: /version
            try:
                r = client.get(f"{server_url}/version")
                if r.status_code == 200:
                    data = r.json()
                    if "version" in data:
                        return "vllm"
            except Exception:
                pass
    except Exception:
        pass

    return None

def _iter_nested_dicts(value: Any):
    if isinstance(value, dict):
        yield value
        for nested in value.values():
            yield from _iter_nested_dicts(nested)
    elif isinstance(value, list):
        for item in value:
            yield from _iter_nested_dicts(item)


def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
    try:
        if isinstance(value, bool):
            return None
        if isinstance(value, str):
            value = value.strip().replace(",", "")
        result = int(value)
    except (TypeError, ValueError):
        return None
    if minimum <= result <= maximum:
        return result
    return None


def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
    keyset = {key.lower() for key in keys}
    for mapping in _iter_nested_dicts(payload):
        for key, value in mapping.items():
            if str(key).lower() not in keyset:
                continue
            coerced = _coerce_reasonable_int(value)
            if coerced is not None:
                return coerced
    return None


def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
    return _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)


def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
    return _extract_first_int(payload, _MAX_COMPLETION_KEYS)


def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
    alias_map = {
        "prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
        "completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
        "request": ("request", "request_cost"),
        "cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
        "cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
    }
    for mapping in _iter_nested_dicts(payload):
        normalized = {str(key).lower(): value for key, value in mapping.items()}
        if not any(any(alias in normalized for alias in aliases) for aliases in alias_map.values()):
            continue
        pricing: Dict[str, Any] = {}
        for target, aliases in alias_map.items():
            for alias in aliases:
                if alias in normalized and normalized[alias] not in (None, ""):
                    pricing[target] = normalized[alias]
                    break
        if pricing:
            return pricing
    return {}


def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
    cache[model_id] = entry
    if "/" in model_id:
        bare_model = model_id.split("/", 1)[1]
        cache.setdefault(bare_model, entry)


def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from OpenRouter (cached for 1 hour)."""
    global _model_metadata_cache, _model_metadata_cache_time

    if not force_refresh and _model_metadata_cache and (time.time() - _model_metadata_cache_time) < _MODEL_CACHE_TTL:
        return _model_metadata_cache

    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
        response.raise_for_status()
        data = response.json()

        cache = {}
        for model in data.get("data", []):
            model_id = model.get("id", "")
            entry = {
                "context_length": model.get("context_length", 128000),
                "max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
                "name": model.get("name", model_id),
                "pricing": model.get("pricing", {}),
            }
            _add_model_aliases(cache, model_id, entry)
            canonical = model.get("canonical_slug", "")
            if canonical and canonical != model_id:
                _add_model_aliases(cache, canonical, entry)

        _model_metadata_cache = cache
        _model_metadata_cache_time = time.time()
        logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
        return cache

    except Exception as e:
        logging.warning(f"Failed to fetch model metadata from OpenRouter: {e}")
        return _model_metadata_cache or {}

def fetch_endpoint_model_metadata(
    base_url: str,
    api_key: str = "",
    force_refresh: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.

    This is used for explicit custom endpoints where hardcoded global model-name
    defaults are unreliable. Results are cached in memory per base URL.
    """
    normalized = _normalize_base_url(base_url)
    if not normalized or _is_openrouter_base_url(normalized):
        return {}

    if not force_refresh:
        cached = _endpoint_model_metadata_cache.get(normalized)
        cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
        if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
            return cached

    candidates = [normalized]
    if normalized.endswith("/v1"):
        alternate = normalized[:-3].rstrip("/")
    else:
        alternate = normalized + "/v1"
    if alternate and alternate not in candidates:
        candidates.append(alternate)

    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    last_error: Optional[Exception] = None

    for candidate in candidates:
        url = candidate.rstrip("/") + "/models"
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            payload = response.json()
            cache: Dict[str, Dict[str, Any]] = {}
            for model in payload.get("data", []):
                if not isinstance(model, dict):
                    continue
                model_id = model.get("id")
                if not model_id:
                    continue
                entry: Dict[str, Any] = {"name": model.get("name", model_id)}
                context_length = _extract_context_length(model)
                if context_length is not None:
                    entry["context_length"] = context_length
                max_completion_tokens = _extract_max_completion_tokens(model)
                if max_completion_tokens is not None:
                    entry["max_completion_tokens"] = max_completion_tokens
                pricing = _extract_pricing(model)
                if pricing:
                    entry["pricing"] = pricing
                _add_model_aliases(cache, model_id, entry)

            # If this is a llama.cpp server, query /props for actual allocated context
            is_llamacpp = any(
                m.get("owned_by") == "llamacpp"
                for m in payload.get("data", []) if isinstance(m, dict)
            )
            if is_llamacpp:
                try:
                    # Try /v1/props first (current llama.cpp); fall back to /props for older builds
                    base = candidate.rstrip("/").replace("/v1", "")
                    props_resp = requests.get(base + "/v1/props", headers=headers, timeout=5)
                    if not props_resp.ok:
                        props_resp = requests.get(base + "/props", headers=headers, timeout=5)
                    if props_resp.ok:
                        props = props_resp.json()
                        gen_settings = props.get("default_generation_settings", {})
                        n_ctx = gen_settings.get("n_ctx")
                        model_alias = props.get("model_alias", "")
                        if n_ctx and model_alias and model_alias in cache:
                            cache[model_alias]["context_length"] = n_ctx
                except Exception:
                    pass

            _endpoint_model_metadata_cache[normalized] = cache
            _endpoint_model_metadata_cache_time[normalized] = time.time()
            return cache
        except Exception as exc:
            last_error = exc

    if last_error:
        logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
    _endpoint_model_metadata_cache[normalized] = {}
    _endpoint_model_metadata_cache_time[normalized] = time.time()
    return {}

def _get_context_cache_path() -> Path:
    """Return path to the persistent context length cache file."""
    hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
    return hermes_home / "context_length_cache.yaml"


def _load_context_cache() -> Dict[str, int]:
    """Load the model+provider -> context_length cache from disk."""
    path = _get_context_cache_path()
    if not path.exists():
        return {}
    try:
        with open(path) as f:
            data = yaml.safe_load(f) or {}
        return data.get("context_lengths", {})
    except Exception as e:
        logger.debug("Failed to load context length cache: %s", e)
        return {}


def save_context_length(model: str, base_url: str, length: int) -> None:
    """Persist a discovered context length for a model+provider combo.

    Cache key is ``model@base_url`` so the same model name served from
    different providers can have different limits.
    """
    key = f"{model}@{base_url}"
    cache = _load_context_cache()
    if cache.get(key) == length:
        return  # already stored
    cache[key] = length
    path = _get_context_cache_path()
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as f:
            yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
        logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
    except Exception as e:
        logger.debug("Failed to save context length cache: %s", e)


def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
    """Look up a previously discovered context length for model+provider."""
    key = f"{model}@{base_url}"
    cache = _load_context_cache()
    return cache.get(key)


def get_next_probe_tier(current_length: int) -> Optional[int]:
    """Return the next lower probe tier, or None if already at minimum."""
    for tier in CONTEXT_PROBE_TIERS:
        if tier < current_length:
            return tier
    return None

def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
    """Try to extract the actual context limit from an API error message.

    Many providers include the limit in their error text, e.g.:
    - "maximum context length is 32768 tokens"
    - "context_length_exceeded: 131072"
    - "Maximum context size 32768 exceeded"
    - "model's max context length is 65536"
    """
    error_lower = error_msg.lower()
    # Pattern: look for numbers near context-related keywords
    patterns = [
        r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
        r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
        r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
        r'>\s*(\d{4,})\s*(?:max|limit|token)',  # "250000 tokens > 200000 maximum"
        r'(\d{4,})\s*(?:max(?:imum)?)\b',  # "200000 maximum"
    ]
    for pattern in patterns:
        match = re.search(pattern, error_lower)
        if match:
            limit = int(match.group(1))
            # Sanity check: must be a reasonable context length
            if 1024 <= limit <= 10_000_000:
                return limit
    return None

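# Illustrative caller-side sketch (an assumption about how the agent layer uses
# these helpers, not something this module does itself): on a context-length
# error, prefer the limit reported by the provider, otherwise step down through
# CONTEXT_PROBE_TIERS, and persist whatever ends up working:
#
#     new_limit = parse_context_limit_from_error(str(api_error))
#     if new_limit is None:
#         new_limit = get_next_probe_tier(current_limit)
#     if new_limit is not None:
#         save_context_length(model, base_url, new_limit)
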
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
    """Return True if *candidate_id* (from server) matches *lookup_model* (configured).

    Supports two forms:
    - Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
    - Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
      (the part after the last "/" equals lookup_model)

    This covers LM Studio's native API which stores models as "publisher/slug"
    while users typically configure only the slug after the "local:" prefix.
    """
    if candidate_id == lookup_model:
        return True
    # Slug match: basename of candidate equals the lookup name
    if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
        return True
    return False


def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
    """Query a local server for the model's context length."""
    import httpx

    # Strip recognised provider prefix (e.g., "local:model-name" → "model-name").
    # Ollama "model:tag" colons (e.g. "qwen3.5:27b") are intentionally preserved.
    model = _strip_provider_prefix(model)

    # Strip /v1 suffix to get the server root
    server_url = base_url.rstrip("/")
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]

    try:
        server_type = detect_local_server_type(base_url)
    except Exception:
        server_type = None

    try:
        with httpx.Client(timeout=3.0) as client:
            # Ollama: /api/show returns model details with context info
            if server_type == "ollama":
                resp = client.post(f"{server_url}/api/show", json={"name": model})
                if resp.status_code == 200:
                    data = resp.json()
                    # Check model_info for context length
                    model_info = data.get("model_info", {})
                    for key, value in model_info.items():
                        if "context_length" in key and isinstance(value, (int, float)):
                            return int(value)
                    # Check parameters string for num_ctx
                    params = data.get("parameters", "")
                    if "num_ctx" in params:
                        for line in params.split("\n"):
                            if "num_ctx" in line:
                                parts = line.strip().split()
                                if len(parts) >= 2:
                                    try:
                                        return int(parts[-1])
                                    except ValueError:
                                        pass

            # LM Studio native API: /api/v1/models returns max_context_length.
            # This is more reliable than the OpenAI-compat /v1/models which
            # doesn't include context window information for LM Studio servers.
            # Use _model_id_matches for fuzzy matching: LM Studio stores models as
            # "publisher/slug" but users configure only "slug" after "local:" prefix.
            if server_type == "lm-studio":
                resp = client.get(f"{server_url}/api/v1/models")
                if resp.status_code == 200:
                    data = resp.json()
                    for m in data.get("models", []):
                        if _model_id_matches(m.get("key", ""), model) or _model_id_matches(m.get("id", ""), model):
                            # Prefer loaded instance context (actual runtime value)
                            for inst in m.get("loaded_instances", []):
                                cfg = inst.get("config", {})
                                ctx = cfg.get("context_length")
                                if ctx and isinstance(ctx, (int, float)):
                                    return int(ctx)
                            # Fall back to max_context_length (theoretical model max)
                            ctx = m.get("max_context_length") or m.get("context_length")
                            if ctx and isinstance(ctx, (int, float)):
                                return int(ctx)

            # LM Studio / vLLM / llama.cpp: try /v1/models/{model}
            resp = client.get(f"{server_url}/v1/models/{model}")
            if resp.status_code == 200:
                data = resp.json()
                # vLLM returns max_model_len
                ctx = data.get("max_model_len") or data.get("context_length") or data.get("max_tokens")
                if ctx and isinstance(ctx, (int, float)):
                    return int(ctx)

            # Try /v1/models and find the model in the list.
            # Use _model_id_matches to handle "publisher/slug" vs bare "slug".
            resp = client.get(f"{server_url}/v1/models")
            if resp.status_code == 200:
                data = resp.json()
                models_list = data.get("data", [])
                for m in models_list:
                    if _model_id_matches(m.get("id", ""), model):
                        ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens")
                        if ctx and isinstance(ctx, (int, float)):
                            return int(ctx)
    except Exception:
        pass

    return None

def _normalize_model_version(model: str) -> str:
    """Normalize version separators for matching.

    Nous uses dashes: claude-opus-4-6, claude-sonnet-4-5
    OpenRouter uses dots: claude-opus-4.6, claude-sonnet-4.5
    Normalize both to dashes for comparison.
    """
    return model.replace(".", "-")


def _query_anthropic_context_length(model: str, base_url: str, api_key: str) -> Optional[int]:
    """Query Anthropic's /v1/models endpoint for context length.

    Only works with regular ANTHROPIC_API_KEY (sk-ant-api*).
    OAuth tokens (sk-ant-oat*) from Claude Code return 401.
    """
    if not api_key or api_key.startswith("sk-ant-oat"):
        return None  # OAuth tokens can't access /v1/models
    try:
        base = base_url.rstrip("/")
        if base.endswith("/v1"):
            base = base[:-3]
        url = f"{base}/v1/models?limit=1000"
        headers = {
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
        }
        resp = requests.get(url, headers=headers, timeout=10)
        if resp.status_code != 200:
            return None
        data = resp.json()
        for m in data.get("data", []):
            if m.get("id") == model:
                ctx = m.get("max_input_tokens")
                if isinstance(ctx, int) and ctx > 0:
                    return ctx
    except Exception as e:
        logger.debug("Anthropic /v1/models query failed: %s", e)
    return None


def _resolve_nous_context_length(model: str) -> Optional[int]:
    """Resolve Nous Portal model context length via OpenRouter metadata.

    Nous model IDs are bare (e.g. 'claude-opus-4-6') while OpenRouter uses
    prefixed IDs (e.g. 'anthropic/claude-opus-4.6'). Try suffix matching
    with version normalization (dot↔dash).
    """
    metadata = fetch_model_metadata()  # OpenRouter cache
    # Exact match first
    if model in metadata:
        return metadata[model].get("context_length")

    normalized = _normalize_model_version(model).lower()

    for or_id, entry in metadata.items():
        bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
        if bare.lower() == model.lower() or _normalize_model_version(bare).lower() == normalized:
            return entry.get("context_length")

    # Partial prefix match for cases like gemini-3-flash → gemini-3-flash-preview
    # Require match to be at a word boundary (followed by -, :, or end of string)
    model_lower = model.lower()
    for or_id, entry in metadata.items():
        bare = or_id.split("/", 1)[1] if "/" in or_id else or_id
        for candidate, query in [(bare.lower(), model_lower), (_normalize_model_version(bare).lower(), normalized)]:
            if candidate.startswith(query) and (
                len(candidate) == len(query) or candidate[len(query)] in "-:."
            ):
                return entry.get("context_length")

    return None

def get_model_context_length(
    model: str,
    base_url: str = "",
    api_key: str = "",
    config_context_length: int | None = None,
    provider: str = "",
) -> int:
    """Get the context length for a model.

    Resolution order:
    0. Explicit config override (model.context_length or custom_providers per-model)
    1. Persistent cache (previously discovered via probing)
    2. Active endpoint metadata (/models for explicit custom endpoints)
    3. Local server query (for local / unrecognised endpoints, falling back to the
       128K probe-down default when nothing is found)
    4. Anthropic /v1/models API (API-key users only, not OAuth)
    5. Provider-aware lookups: Nous suffix-match via the OpenRouter cache,
       then the models.dev registry
    6. OpenRouter live API metadata
    7. Thin hardcoded defaults (broad family patterns)
    8. Local server query as a last resort
    9. Default fallback (128K)
    """
    # 0. Explicit config override — user knows best
    if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
        return config_context_length

    # Normalise provider-prefixed model names (e.g. "local:model-name" →
    # "model-name") so cache lookups and server queries use the bare ID that
    # local servers actually know about. Ollama "model:tag" colons are preserved.
    model = _strip_provider_prefix(model)

    # 1. Check persistent cache (model+provider)
    if base_url:
        cached = get_cached_context_length(model, base_url)
        if cached is not None:
            return cached

    # 2. Active endpoint metadata for truly custom/unknown endpoints.
    # Known providers (Copilot, OpenAI, Anthropic, etc.) skip this — their
    # /models endpoint may report a provider-imposed limit (e.g. Copilot
    # returns 128k) instead of the model's full context (400k). models.dev
    # has the correct per-provider values and is checked at step 5.
    if _is_custom_endpoint(base_url) and not _is_known_provider_base_url(base_url):
        endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
        matched = endpoint_metadata.get(model)
        if not matched:
            # Single-model servers: if only one model is loaded, use it
            if len(endpoint_metadata) == 1:
                matched = next(iter(endpoint_metadata.values()))
            else:
                # Fuzzy match: substring in either direction
                for key, entry in endpoint_metadata.items():
                    if model in key or key in model:
                        matched = entry
                        break
        if matched:
            context_length = matched.get("context_length")
            if isinstance(context_length, int):
                return context_length
    if not _is_known_provider_base_url(base_url):
        # 3. Try querying local server directly
        if is_local_endpoint(base_url):
            local_ctx = _query_local_context_length(model, base_url)
            if local_ctx and local_ctx > 0:
                save_context_length(model, base_url, local_ctx)
                return local_ctx
        logger.info(
            "Could not detect context length for model %r at %s — "
            "defaulting to %s tokens (probe-down). Set model.context_length "
            "in config.yaml to override.",
            model, base_url, f"{DEFAULT_FALLBACK_CONTEXT:,}",
        )
        return DEFAULT_FALLBACK_CONTEXT

    # 4. Anthropic /v1/models API (only for regular API keys, not OAuth)
    if provider == "anthropic" or (
        base_url and "api.anthropic.com" in base_url
    ):
        ctx = _query_anthropic_context_length(model, base_url or "https://api.anthropic.com", api_key)
        if ctx:
            return ctx

    # 5. Provider-aware lookups (before the generic OpenRouter cache)
    # These are provider-specific and take priority over the generic OR cache,
    # since the same model can have different context limits per provider
    # (e.g. claude-opus-4.6 is 1M on Anthropic but 128K on GitHub Copilot).
    # If provider is generic (openrouter/custom/empty), try to infer from URL.
    effective_provider = provider
    if not effective_provider or effective_provider in ("openrouter", "custom"):
        if base_url:
            inferred = _infer_provider_from_url(base_url)
            if inferred:
                effective_provider = inferred

    if effective_provider == "nous":
        ctx = _resolve_nous_context_length(model)
        if ctx:
            return ctx
    if effective_provider:
        from agent.models_dev import lookup_models_dev_context
        ctx = lookup_models_dev_context(effective_provider, model)
        if ctx:
            return ctx

    # 6. OpenRouter live API metadata (provider-unaware fallback)
    metadata = fetch_model_metadata()
    if model in metadata:
        return metadata[model].get("context_length", 128000)

    # 7. Hardcoded defaults (fuzzy match — longest key first for specificity)
    # Only check `default_model in model` (is the key a substring of the input).
    # The reverse (`model in default_model`) causes shorter names like
    # "claude-sonnet-4" to incorrectly match "claude-sonnet-4-6" and return 1M.
    model_lower = model.lower()
    for default_model, length in sorted(
        DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
    ):
        if default_model in model_lower:
            return length

    # 8. Query local server as last resort
    if base_url and is_local_endpoint(base_url):
        local_ctx = _query_local_context_length(model, base_url)
        if local_ctx and local_ctx > 0:
            save_context_length(model, base_url, local_ctx)
            return local_ctx

    # 9. Default fallback — 128K
    return DEFAULT_FALLBACK_CONTEXT

def estimate_tokens_rough(text: str) -> int:
    """Rough token estimate (~4 chars/token) for pre-flight checks."""
    if not text:
        return 0
    return len(text) // 4


def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
    """Rough token estimate for a message list (pre-flight only)."""
    total_chars = sum(len(str(msg)) for msg in messages)
    return total_chars // 4
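

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; not used by the agent). Resolution
# may hit the network (models.dev, OpenRouter, or a /models endpoint), so the
# printed values depend on what those sources return at run time.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Known-provider endpoint (GitHub Copilot): the /models probe is skipped,
    # so the value should come from models.dev rather than Copilot's 128k cap.
    print(get_model_context_length(
        "gpt-5.4",
        base_url="https://api.githubcopilot.com",
        provider="copilot",
    ))

    # Rough pre-flight estimate for a message list (~4 chars per token).
    messages = [{"role": "user", "content": "hello world " * 500}]
    print(estimate_messages_tokens_rough(messages))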