* fix: preserve Ollama model:tag colons in context length detection The colon-split logic in get_model_context_length() and _query_local_context_length() assumed any colon meant provider:model format (e.g. "local:my-model"). But Ollama uses model:tag format (e.g. "qwen3.5:27b"), so the split turned "qwen3.5:27b" into just "27b" — which matches nothing, causing a fallback to the 2M token probe tier. Now only recognised provider prefixes (local, openrouter, anthropic, etc.) are stripped. Ollama model:tag names pass through intact. * fix: update claude-opus-4-6 and claude-sonnet-4-6 context length from 200K to 1M Both models support 1,000,000 token context windows. The hardcoded defaults were set before Anthropic expanded the context for the 4.6 generation. Verified via models.dev and OpenRouter API data. --------- Co-authored-by: kshitijk4poor <82637225+kshitijk4poor@users.noreply.github.com> Co-authored-by: Test <test@test.com>
797 lines
30 KiB
Python
797 lines
30 KiB
Python
"""Model metadata, context lengths, and token estimation utilities.
|
|
|
|
Pure utility functions with no AIAgent dependency. Used by ContextCompressor
|
|
and run_agent.py for pre-flight context checks.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
import yaml
|
|
|
|
from hermes_constants import OPENROUTER_MODELS_URL
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Provider names that can appear as a "provider:" prefix before a model ID.
|
|
# Only these are stripped — Ollama-style "model:tag" colons (e.g. "qwen3.5:27b")
|
|
# are preserved so the full model name reaches cache lookups and server queries.
|
|
_PROVIDER_PREFIXES: frozenset[str] = frozenset({
|
|
"openrouter", "nous", "openai-codex", "copilot", "copilot-acp",
|
|
"zai", "kimi-coding", "minimax", "minimax-cn", "anthropic", "deepseek",
|
|
"opencode-zen", "opencode-go", "ai-gateway", "kilocode", "alibaba",
|
|
"custom", "local",
|
|
# Common aliases
|
|
"glm", "z-ai", "z.ai", "zhipu", "github", "github-copilot",
|
|
"github-models", "kimi", "moonshot", "claude", "deep-seek",
|
|
"opencode", "zen", "go", "vercel", "kilo", "dashscope", "aliyun", "qwen",
|
|
})
|
|
|
|
|
|
def _strip_provider_prefix(model: str) -> str:
|
|
"""Strip a recognised provider prefix from a model string.
|
|
|
|
``"local:my-model"`` → ``"my-model"``
|
|
``"qwen3.5:27b"`` → ``"qwen3.5:27b"`` (unchanged — not a provider prefix)
|
|
"""
|
|
if ":" not in model or model.startswith("http"):
|
|
return model
|
|
prefix = model.split(":", 1)[0].strip().lower()
|
|
if prefix in _PROVIDER_PREFIXES:
|
|
return model.split(":", 1)[1]
|
|
return model
|
|
|
|
# In-process cache of OpenRouter model metadata: model ID -> metadata dict.
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
# Wall-clock timestamp (time.time()) of the last successful OpenRouter fetch.
_model_metadata_cache_time: float = 0
# OpenRouter metadata is re-fetched at most once per hour.
_MODEL_CACHE_TTL = 3600
# Per-endpoint /models metadata, keyed by normalised base URL.
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
# Fetch timestamp per base URL, used for TTL expiry of the cache above.
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
# Custom endpoints are cheap to re-query; keep their metadata only 5 minutes.
_ENDPOINT_MODEL_CACHE_TTL = 300

# Descending tiers for context length probing when the model is unknown.
# We start high and step down on context-length errors until one works.
CONTEXT_PROBE_TIERS = [
    2_000_000,
    1_000_000,
    512_000,
    200_000,
    128_000,
    64_000,
    32_000,
]
|
|
|
|
# Hardcoded fallback context windows, consulted by get_model_context_length()
# only after the persistent cache, endpoint metadata, and OpenRouter metadata
# all fail to resolve the model. Lookup is fuzzy (substring match, longest key
# first), so more specific IDs take precedence over generic stems.
DEFAULT_CONTEXT_LENGTHS = {
    "anthropic/claude-opus-4": 200000,
    "anthropic/claude-opus-4.5": 200000,
    "anthropic/claude-opus-4.6": 1000000,
    "anthropic/claude-sonnet-4": 200000,
    "anthropic/claude-sonnet-4-20250514": 200000,
    "anthropic/claude-sonnet-4.5": 200000,
    "anthropic/claude-sonnet-4.6": 1000000,
    "anthropic/claude-haiku-4.5": 200000,
    # Bare Anthropic model IDs (for native API provider)
    "claude-opus-4-6": 1000000,
    "claude-sonnet-4-6": 1000000,
    "claude-opus-4-5-20251101": 200000,
    "claude-sonnet-4-5-20250929": 200000,
    "claude-opus-4-1-20250805": 200000,
    "claude-opus-4-20250514": 200000,
    "claude-sonnet-4-20250514": 200000,
    "claude-haiku-4-5-20251001": 200000,
    "openai/gpt-5": 128000,
    "openai/gpt-4.1": 1047576,
    "openai/gpt-4.1-mini": 1047576,
    "openai/gpt-4o": 128000,
    "openai/gpt-4-turbo": 128000,
    "openai/gpt-4o-mini": 128000,
    "google/gemini-3-pro-preview": 1048576,
    "google/gemini-3-flash": 1048576,
    "google/gemini-2.5-flash": 1048576,
    "google/gemini-2.0-flash": 1048576,
    "google/gemini-2.5-pro": 1048576,
    "deepseek/deepseek-v3.2": 65536,
    "meta-llama/llama-3.3-70b-instruct": 131072,
    "deepseek/deepseek-chat-v3": 65536,
    "qwen/qwen-2.5-72b-instruct": 32768,
    "glm-4.7": 202752,
    "glm-5": 202752,
    "glm-4.5": 131072,
    "glm-4.5-flash": 131072,
    "kimi-for-coding": 262144,
    "kimi-k2.5": 262144,
    "kimi-k2-thinking": 262144,
    "kimi-k2-thinking-turbo": 262144,
    "kimi-k2-turbo-preview": 262144,
    "kimi-k2-0905-preview": 131072,
    "MiniMax-M2.7": 204800,
    "MiniMax-M2.7-highspeed": 204800,
    "MiniMax-M2.5": 204800,
    "MiniMax-M2.5-highspeed": 204800,
    "MiniMax-M2.1": 204800,
    # OpenCode Zen models
    "gpt-5.4-pro": 128000,
    "gpt-5.4": 128000,
    "gpt-5.3-codex": 128000,
    "gpt-5.3-codex-spark": 128000,
    "gpt-5.2": 128000,
    "gpt-5.2-codex": 128000,
    "gpt-5.1": 128000,
    "gpt-5.1-codex": 128000,
    "gpt-5.1-codex-max": 128000,
    "gpt-5.1-codex-mini": 128000,
    "gpt-5": 128000,
    "gpt-5-codex": 128000,
    "gpt-5-nano": 128000,
    # Bare model IDs without provider prefix (avoid duplicates with entries above)
    "claude-opus-4-5": 200000,
    "claude-opus-4-1": 200000,
    "claude-sonnet-4-5": 200000,
    "claude-sonnet-4": 200000,
    "claude-haiku-4-5": 200000,
    "claude-3-5-haiku": 200000,
    "gemini-3.1-pro": 1048576,
    "gemini-3-pro": 1048576,
    "gemini-3-flash": 1048576,
    "minimax-m2.5": 204800,
    "minimax-m2.5-free": 204800,
    "minimax-m2.1": 204800,
    "glm-4.6": 202752,
    "kimi-k2": 262144,
    "qwen3-coder": 32768,
    "big-pickle": 128000,
    # Alibaba Cloud / DashScope Qwen models
    "qwen3.5-plus": 131072,
    "qwen3-max": 131072,
    "qwen3-coder-plus": 131072,
    "qwen3-coder-next": 131072,
    "qwen-plus-latest": 131072,
    "qwen3.5-flash": 131072,
    "qwen-vl-max": 32768,
}
|
|
|
|
# Field names, across providers and local servers, that may carry the model's
# total context window. Searched case-insensitively through nested metadata
# dicts by _extract_first_int; order reflects decreasing specificity.
_CONTEXT_LENGTH_KEYS = (
    "context_length",
    "context_window",
    "max_context_length",
    "max_position_embeddings",
    "max_model_len",
    "max_input_tokens",
    "max_sequence_length",
    "max_seq_len",
    "n_ctx_train",
    "n_ctx",
)

# Field names that may carry the maximum completion (output) token budget.
_MAX_COMPLETION_KEYS = (
    "max_completion_tokens",
    "max_output_tokens",
    "max_tokens",
)

# Local server hostnames / address patterns treated as "this machine".
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
|
|
|
|
|
def _normalize_base_url(base_url: str) -> str:
|
|
return (base_url or "").strip().rstrip("/")
|
|
|
|
|
|
def _is_openrouter_base_url(base_url: str) -> bool:
    """True when *base_url* routes to OpenRouter (host text contains "openrouter.ai")."""
    normalized = _normalize_base_url(base_url).lower()
    return "openrouter.ai" in normalized
|
|
|
|
|
|
def _is_custom_endpoint(base_url: str) -> bool:
    """True for a non-empty base URL that is not OpenRouter (i.e. a custom route)."""
    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False
    return not _is_openrouter_base_url(normalized)
|
|
|
|
|
|
def _is_known_provider_base_url(base_url: str) -> bool:
    """Return True when *base_url* points at a recognised hosted provider.

    Matching is substring-based on the host portion, so e.g. "api.minimax"
    covers regional MiniMax domains. Scheme-less input is parsed as https.
    """
    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False
    target = normalized if "://" in normalized else f"https://{normalized}"
    parsed = urlparse(target)
    # Scheme-less strings can land in .path instead of .netloc.
    host = parsed.netloc.lower() or parsed.path.lower()
    known_hosts = (
        "api.openai.com",
        "chatgpt.com",
        "api.anthropic.com",
        "api.z.ai",
        "api.moonshot.ai",
        "api.kimi.com",
        "api.minimax",
    )
    for known_host in known_hosts:
        if known_host in host:
            return True
    return False
|
|
|
|
|
|
def is_local_endpoint(base_url: str) -> bool:
    """Return True if base_url points to a local machine (localhost / RFC-1918 / WSL)."""
    import ipaddress

    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False
    with_scheme = normalized if "://" in normalized else f"http://{normalized}"
    try:
        host = urlparse(with_scheme).hostname or ""
    except Exception:
        return False
    # Well-known local hostnames / addresses.
    if host in _LOCAL_HOSTS:
        return True
    # Any parseable IP: private, loopback, or link-local counts as local.
    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        pass
    else:
        return addr.is_private or addr.is_loopback or addr.is_link_local
    # Fallback for dotted-quad strings ipaddress rejected: check the classic
    # private ranges by octet (e.g. 172.26.x.x as used by WSL).
    octets = host.split(".")
    if len(octets) == 4:
        try:
            first, second = int(octets[0]), int(octets[1])
        except ValueError:
            return False
        if first == 10:
            return True
        if first == 172 and 16 <= second <= 31:
            return True
        if first == 192 and second == 168:
            return True
    return False
|
|
|
|
|
|
def detect_local_server_type(base_url: str) -> Optional[str]:
    """Detect which local server is running at base_url by probing known endpoints.

    Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.

    Probe order matters: LM Studio is checked first because it answers 200
    (with an error body) on paths other servers own. Every probe failure is
    swallowed so one unreachable path never aborts detection.
    """
    import httpx  # deferred: only needed when probing a local server

    normalized = _normalize_base_url(base_url)
    server_url = normalized
    # Probes hit server-root paths, so drop an OpenAI-style "/v1" suffix.
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]

    try:
        with httpx.Client(timeout=2.0) as client:
            # LM Studio exposes /api/v1/models — check first (most specific)
            try:
                r = client.get(f"{server_url}/api/v1/models")
                if r.status_code == 200:
                    return "lm-studio"
            except Exception:
                pass
            # Ollama exposes /api/tags and responds with {"models": [...]}
            # LM Studio returns {"error": "Unexpected endpoint"} with status 200
            # on this path, so we must verify the response contains "models".
            try:
                r = client.get(f"{server_url}/api/tags")
                if r.status_code == 200:
                    try:
                        data = r.json()
                        if "models" in data:
                            return "ollama"
                    except Exception:
                        pass
            except Exception:
                pass
            # llama.cpp exposes /props
            try:
                r = client.get(f"{server_url}/props")
                if r.status_code == 200 and "default_generation_settings" in r.text:
                    return "llamacpp"
            except Exception:
                pass
            # vLLM: /version (verify the JSON body actually has a version field)
            try:
                r = client.get(f"{server_url}/version")
                if r.status_code == 200:
                    data = r.json()
                    if "version" in data:
                        return "vllm"
            except Exception:
                pass
    except Exception:
        pass

    return None
|
|
|
|
|
|
def _iter_nested_dicts(value: Any):
|
|
if isinstance(value, dict):
|
|
yield value
|
|
for nested in value.values():
|
|
yield from _iter_nested_dicts(nested)
|
|
elif isinstance(value, list):
|
|
for item in value:
|
|
yield from _iter_nested_dicts(item)
|
|
|
|
|
|
def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
|
|
try:
|
|
if isinstance(value, bool):
|
|
return None
|
|
if isinstance(value, str):
|
|
value = value.strip().replace(",", "")
|
|
result = int(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
if minimum <= result <= maximum:
|
|
return result
|
|
return None
|
|
|
|
|
|
def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
    """Scan nested dicts in *payload* for the first plausible int under any of *keys*.

    Key comparison is case-insensitive; candidate values go through
    _coerce_reasonable_int so junk (bools, tiny/huge numbers) is skipped
    and the search continues.
    """
    wanted = {name.lower() for name in keys}
    for mapping in _iter_nested_dicts(payload):
        for key, value in mapping.items():
            if str(key).lower() in wanted:
                coerced = _coerce_reasonable_int(value)
                if coerced is not None:
                    return coerced
    return None
|
|
|
|
|
|
def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
    """Return the first plausible context-window value found anywhere in *payload*."""
    return _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
|
|
|
|
|
|
def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
    """Return the first plausible max-output-token value found anywhere in *payload*."""
    return _extract_first_int(payload, _MAX_COMPLETION_KEYS)
|
|
|
|
|
|
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Pull a normalised pricing dict out of provider metadata.

    Providers name token-pricing fields inconsistently; the alias map folds
    them into canonical keys (prompt, completion, request, cache_read,
    cache_write). The first nested dict containing at least one alias wins;
    {} is returned when nothing matches anywhere.
    """
    alias_map = {
        "prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
        "completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
        "request": ("request", "request_cost"),
        "cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
        "cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
    }
    for mapping in _iter_nested_dicts(payload):
        lowered = {str(key).lower(): value for key, value in mapping.items()}
        has_any_alias = any(
            alias in lowered
            for aliases in alias_map.values()
            for alias in aliases
        )
        if not has_any_alias:
            continue
        pricing: Dict[str, Any] = {}
        for target, aliases in alias_map.items():
            for alias in aliases:
                # Skip empty/None so a blank field doesn't mask a later alias.
                if alias in lowered and lowered[alias] not in (None, ""):
                    pricing[target] = lowered[alias]
                    break
        if pricing:
            return pricing
    return {}
|
|
|
|
|
|
def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
|
|
cache[model_id] = entry
|
|
if "/" in model_id:
|
|
bare_model = model_id.split("/", 1)[1]
|
|
cache.setdefault(bare_model, entry)
|
|
|
|
|
|
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from OpenRouter (cached in-process for 1 hour).

    Args:
        force_refresh: Bypass the TTL cache and re-query the API.

    Returns:
        Mapping of model ID (plus bare and canonical-slug aliases) to a dict
        with context_length, max_completion_tokens, name, and pricing.
        On any failure the previous (possibly empty) cache is returned —
        this function never raises.
    """
    global _model_metadata_cache, _model_metadata_cache_time

    cache_is_fresh = (time.time() - _model_metadata_cache_time) < _MODEL_CACHE_TTL
    if not force_refresh and _model_metadata_cache and cache_is_fresh:
        return _model_metadata_cache

    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
        response.raise_for_status()
        data = response.json()

        cache = {}
        for model in data.get("data", []):
            model_id = model.get("id", "")
            entry = {
                # OpenRouter omissions get conservative defaults.
                "context_length": model.get("context_length", 128000),
                "max_completion_tokens": model.get("top_provider", {}).get("max_completion_tokens", 4096),
                "name": model.get("name", model_id),
                "pricing": model.get("pricing", {}),
            }
            _add_model_aliases(cache, model_id, entry)
            # Canonical slug is an alternate stable ID; alias it too.
            canonical = model.get("canonical_slug", "")
            if canonical and canonical != model_id:
                _add_model_aliases(cache, canonical, entry)

        _model_metadata_cache = cache
        _model_metadata_cache_time = time.time()
        logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
        return cache

    except Exception as e:
        # Fix: use the module logger with lazy %-args (was the root logger via
        # logging.warning(f"...")), consistent with every other log call here.
        logger.warning("Failed to fetch model metadata from OpenRouter: %s", e)
        return _model_metadata_cache or {}
|
|
|
|
|
|
def fetch_endpoint_model_metadata(
    base_url: str,
    api_key: str = "",
    force_refresh: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.

    This is used for explicit custom endpoints where hardcoded global model-name
    defaults are unreliable. Results are cached in memory per base URL.

    Args:
        base_url: Endpoint root; both the configured form and a "/v1"-toggled
            variant are tried.
        api_key: Optional bearer token for the /models request.
        force_refresh: Skip the per-endpoint TTL cache.

    Returns:
        Mapping of model ID (plus bare aliases) to metadata; {} on failure.
        Failures are cached too, so a dead endpoint is not re-probed on
        every call within the TTL window.
    """
    normalized = _normalize_base_url(base_url)
    # OpenRouter has its own richer metadata path (fetch_model_metadata).
    if not normalized or _is_openrouter_base_url(normalized):
        return {}

    if not force_refresh:
        cached = _endpoint_model_metadata_cache.get(normalized)
        cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
        if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
            return cached

    # Try the URL as configured, then with "/v1" toggled on/off.
    candidates = [normalized]
    if normalized.endswith("/v1"):
        alternate = normalized[:-3].rstrip("/")
    else:
        alternate = normalized + "/v1"
    if alternate and alternate not in candidates:
        candidates.append(alternate)

    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    last_error: Optional[Exception] = None

    for candidate in candidates:
        url = candidate.rstrip("/") + "/models"
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            payload = response.json()
            cache: Dict[str, Dict[str, Any]] = {}
            for model in payload.get("data", []):
                if not isinstance(model, dict):
                    continue
                model_id = model.get("id")
                if not model_id:
                    continue
                entry: Dict[str, Any] = {"name": model.get("name", model_id)}
                # Record only fields the server actually reported; absent
                # values must not shadow other resolution sources.
                context_length = _extract_context_length(model)
                if context_length is not None:
                    entry["context_length"] = context_length
                max_completion_tokens = _extract_max_completion_tokens(model)
                if max_completion_tokens is not None:
                    entry["max_completion_tokens"] = max_completion_tokens
                pricing = _extract_pricing(model)
                if pricing:
                    entry["pricing"] = pricing
                _add_model_aliases(cache, model_id, entry)

            # If this is a llama.cpp server, query /props for actual allocated context
            is_llamacpp = any(
                m.get("owned_by") == "llamacpp"
                for m in payload.get("data", []) if isinstance(m, dict)
            )
            if is_llamacpp:
                try:
                    props_url = candidate.rstrip("/").replace("/v1", "") + "/props"
                    props_resp = requests.get(props_url, headers=headers, timeout=5)
                    if props_resp.ok:
                        props = props_resp.json()
                        gen_settings = props.get("default_generation_settings", {})
                        n_ctx = gen_settings.get("n_ctx")
                        model_alias = props.get("model_alias", "")
                        # n_ctx is the runtime allocation and overrides the
                        # model's theoretical maximum from /models.
                        if n_ctx and model_alias and model_alias in cache:
                            cache[model_alias]["context_length"] = n_ctx
                except Exception:
                    pass

            _endpoint_model_metadata_cache[normalized] = cache
            _endpoint_model_metadata_cache_time[normalized] = time.time()
            return cache
        except Exception as exc:
            last_error = exc

    if last_error:
        logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
    # Negative-cache the failure so the endpoint isn't hammered within the TTL.
    _endpoint_model_metadata_cache[normalized] = {}
    _endpoint_model_metadata_cache_time[normalized] = time.time()
    return {}
|
|
|
|
|
|
def _get_context_cache_path() -> Path:
|
|
"""Return path to the persistent context length cache file."""
|
|
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
|
return hermes_home / "context_length_cache.yaml"
|
|
|
|
|
|
def _load_context_cache() -> Dict[str, int]:
    """Load the model+provider -> context_length cache from disk.

    Returns an empty dict when the file is absent or unreadable — a corrupt
    cache must never break context detection.
    """
    path = _get_context_cache_path()
    if not path.exists():
        return {}
    try:
        with open(path) as handle:
            data = yaml.safe_load(handle) or {}
        return data.get("context_lengths", {})
    except Exception as e:
        logger.debug("Failed to load context length cache: %s", e)
        return {}
|
|
|
|
|
|
def save_context_length(model: str, base_url: str, length: int) -> None:
    """Persist a discovered context length for a model+provider combo.

    Cache key is ``model@base_url`` so the same model name served from
    different providers can have different limits. Write failures are
    logged at debug level and otherwise ignored.
    """
    key = f"{model}@{base_url}"
    cache = _load_context_cache()
    if cache.get(key) == length:
        # Same value already on disk — skip the rewrite.
        return
    cache[key] = length
    path = _get_context_cache_path()
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w") as handle:
            yaml.dump({"context_lengths": cache}, handle, default_flow_style=False)
        logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
    except Exception as e:
        logger.debug("Failed to save context length cache: %s", e)
|
|
|
|
|
|
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
    """Look up a previously discovered context length for model+provider."""
    return _load_context_cache().get(f"{model}@{base_url}")
|
|
|
|
|
|
def get_next_probe_tier(current_length: int) -> Optional[int]:
    """Return the next probe tier strictly below *current_length*, or None at the floor.

    CONTEXT_PROBE_TIERS is sorted descending, so the first qualifying tier
    is the largest remaining candidate.
    """
    candidates = (tier for tier in CONTEXT_PROBE_TIERS if tier < current_length)
    return next(candidates, None)
|
|
|
|
|
|
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
    """Try to extract the actual context limit from an API error message.

    Many providers include the limit in their error text, e.g.:
    - "maximum context length is 32768 tokens"
    - "context_length_exceeded: 131072"
    - "Maximum context size 32768 exceeded"
    - "model's max context length is 65536"

    Returns the parsed limit, or None when no pattern yields a value in
    the plausible 1024..10M range.
    """
    error_lower = error_msg.lower()
    # Numbers of 4+ digits adjacent to context-related keywords.
    patterns = (
        r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
        r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
        r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
        r'>\s*(\d{4,})\s*(?:max|limit|token)',  # "250000 tokens > 200000 maximum"
        r'(\d{4,})\s*(?:max(?:imum)?)\b',  # "200000 maximum"
    )
    for pattern in patterns:
        match = re.search(pattern, error_lower)
        if not match:
            continue
        limit = int(match.group(1))
        # Sanity window: implausible numbers fall through to the next pattern.
        if 1024 <= limit <= 10_000_000:
            return limit
    return None
|
|
|
|
|
|
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
|
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
|
|
|
|
Supports two forms:
|
|
- Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
|
|
- Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
|
|
(the part after the last "/" equals lookup_model)
|
|
|
|
This covers LM Studio's native API which stores models as "publisher/slug"
|
|
while users typically configure only the slug after the "local:" prefix.
|
|
"""
|
|
if candidate_id == lookup_model:
|
|
return True
|
|
# Slug match: basename of candidate equals the lookup name
|
|
if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
|
|
return True
|
|
return False
|
|
|
|
|
|
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
    """Query a local server for the model's context length.

    Tries server-specific APIs first (Ollama /api/show, LM Studio
    /api/v1/models), then generic OpenAI-compatible routes. Returns the
    context window in tokens, or None when nothing answered usefully.
    """
    import httpx  # deferred: only needed when a local endpoint is configured

    # Strip recognised provider prefix (e.g., "local:model-name" → "model-name").
    # Ollama "model:tag" colons (e.g. "qwen3.5:27b") are intentionally preserved.
    model = _strip_provider_prefix(model)

    # Strip /v1 suffix to get the server root
    server_url = base_url.rstrip("/")
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]

    try:
        server_type = detect_local_server_type(base_url)
    except Exception:
        server_type = None

    try:
        with httpx.Client(timeout=3.0) as client:
            # Ollama: /api/show returns model details with context info
            if server_type == "ollama":
                resp = client.post(f"{server_url}/api/show", json={"name": model})
                if resp.status_code == 200:
                    data = resp.json()
                    # Check model_info for context length. Substring match
                    # because keys appear to be prefixed (e.g.
                    # "llama.context_length") — TODO confirm against Ollama docs.
                    model_info = data.get("model_info", {})
                    for key, value in model_info.items():
                        if "context_length" in key and isinstance(value, (int, float)):
                            return int(value)
                    # Check parameters string for a "num_ctx <value>" line
                    params = data.get("parameters", "")
                    if "num_ctx" in params:
                        for line in params.split("\n"):
                            if "num_ctx" in line:
                                parts = line.strip().split()
                                if len(parts) >= 2:
                                    try:
                                        return int(parts[-1])
                                    except ValueError:
                                        pass

            # LM Studio native API: /api/v1/models returns max_context_length.
            # This is more reliable than the OpenAI-compat /v1/models which
            # doesn't include context window information for LM Studio servers.
            # Use _model_id_matches for fuzzy matching: LM Studio stores models as
            # "publisher/slug" but users configure only "slug" after "local:" prefix.
            if server_type == "lm-studio":
                resp = client.get(f"{server_url}/api/v1/models")
                if resp.status_code == 200:
                    data = resp.json()
                    for m in data.get("models", []):
                        if _model_id_matches(m.get("key", ""), model) or _model_id_matches(m.get("id", ""), model):
                            # Prefer loaded instance context (actual runtime value)
                            for inst in m.get("loaded_instances", []):
                                cfg = inst.get("config", {})
                                ctx = cfg.get("context_length")
                                if ctx and isinstance(ctx, (int, float)):
                                    return int(ctx)
                            # Fall back to max_context_length (theoretical model max)
                            ctx = m.get("max_context_length") or m.get("context_length")
                            if ctx and isinstance(ctx, (int, float)):
                                return int(ctx)

            # LM Studio / vLLM / llama.cpp: try /v1/models/{model}
            resp = client.get(f"{server_url}/v1/models/{model}")
            if resp.status_code == 200:
                data = resp.json()
                # vLLM returns max_model_len
                ctx = data.get("max_model_len") or data.get("context_length") or data.get("max_tokens")
                if ctx and isinstance(ctx, (int, float)):
                    return int(ctx)

            # Try /v1/models and find the model in the list.
            # Use _model_id_matches to handle "publisher/slug" vs bare "slug".
            resp = client.get(f"{server_url}/v1/models")
            if resp.status_code == 200:
                data = resp.json()
                models_list = data.get("data", [])
                for m in models_list:
                    if _model_id_matches(m.get("id", ""), model):
                        ctx = m.get("max_model_len") or m.get("context_length") or m.get("max_tokens")
                        if ctx and isinstance(ctx, (int, float)):
                            return int(ctx)
    except Exception:
        # Best-effort probe: any network/parse failure just means "unknown".
        pass

    return None
|
|
|
|
|
|
def get_model_context_length(
    model: str,
    base_url: str = "",
    api_key: str = "",
    config_context_length: int | None = None,
) -> int:
    """Get the context length for a model.

    Resolution order:
    0. Explicit config override (model.context_length in config.yaml)
    1. Persistent cache (previously discovered via probing)
    2. Active endpoint metadata (/models for explicit custom endpoints)
    3. Local server query (for local endpoints when model not in /models list)
    4. OpenRouter API metadata
    5. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
    6. First probe tier (2M) — will be narrowed on first context error

    Args:
        model: Model identifier, possibly "provider:model" prefixed.
        base_url: API base URL ("" means the default hosted route).
        api_key: Bearer token forwarded to /models queries on custom endpoints.
        config_context_length: Explicit user override; wins when a positive int.

    Returns:
        Context window size in tokens (always a positive int).
    """
    # 0. Explicit config override — user knows best
    if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
        return config_context_length

    # Normalise provider-prefixed model names (e.g. "local:model-name" →
    # "model-name") so cache lookups and server queries use the bare ID that
    # local servers actually know about. Ollama "model:tag" colons are preserved.
    model = _strip_provider_prefix(model)

    # 1. Check persistent cache (model+provider)
    if base_url:
        cached = get_cached_context_length(model, base_url)
        if cached is not None:
            return cached

    # 2. Active endpoint metadata for explicit custom routes
    if _is_custom_endpoint(base_url):
        endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
        matched = endpoint_metadata.get(model)
        if not matched:
            # Single-model servers: if only one model is loaded, use it
            if len(endpoint_metadata) == 1:
                matched = next(iter(endpoint_metadata.values()))
            else:
                # Fuzzy match: substring in either direction
                for key, entry in endpoint_metadata.items():
                    if model in key or key in model:
                        matched = entry
                        break
        if matched:
            context_length = matched.get("context_length")
            if isinstance(context_length, int):
                return context_length
        if not _is_known_provider_base_url(base_url):
            # Explicit third-party endpoints should not borrow fuzzy global
            # defaults from unrelated providers with similarly named models.
            # But first try querying the local server directly.
            if is_local_endpoint(base_url):
                local_ctx = _query_local_context_length(model, base_url)
                if local_ctx and local_ctx > 0:
                    save_context_length(model, base_url, local_ctx)
                    return local_ctx
            logger.info(
                "Could not detect context length for model %r at %s — "
                "defaulting to %s tokens (probe-down). Set model.context_length "
                "in config.yaml to override.",
                model, base_url, f"{CONTEXT_PROBE_TIERS[0]:,}",
            )
            return CONTEXT_PROBE_TIERS[0]

    # 3. OpenRouter API metadata
    metadata = fetch_model_metadata()
    if model in metadata:
        return metadata[model].get("context_length", 128000)

    # 4. Hardcoded defaults (fuzzy match — longest key first for specificity)
    for default_model, length in sorted(
        DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
    ):
        if default_model in model or model in default_model:
            return length

    # 5. Query local server for unknown models before defaulting to 2M
    if base_url and is_local_endpoint(base_url):
        local_ctx = _query_local_context_length(model, base_url)
        if local_ctx and local_ctx > 0:
            save_context_length(model, base_url, local_ctx)
            return local_ctx

    # 6. Unknown model — start at highest probe tier
    return CONTEXT_PROBE_TIERS[0]
|
|
|
|
|
|
def estimate_tokens_rough(text: str) -> int:
    """Rough token estimate (~4 chars/token) for pre-flight checks."""
    return len(text) // 4 if text else 0
|
|
|
|
|
|
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
    """Rough token estimate for a message list (pre-flight only).

    Sums the str() length of each message (keys, values, and dict
    punctuation all count) and divides by ~4 chars/token once at the end.
    """
    return sum(len(str(msg)) for msg in messages) // 4
|