Custom endpoints (LM Studio, Ollama, vLLM, llama.cpp) silently fall
back to a 2M-token context window when /v1/models doesn't report
context_length.
Adds _query_local_context_length() which queries server-specific APIs:
- LM Studio: /api/v1/models (max_context_length + loaded instances)
- Ollama: /api/show (model_info + num_ctx parameters)
- llama.cpp: /props (n_ctx from default_generation_settings)
- vLLM: /v1/models/{model} (max_model_len)
For LM Studio, prefers a loaded instance's context length over the model's
maximum (e.g., 122K loaded vs. 1M max).
Results are cached via save_context_length() to avoid repeated queries.
Also fixes detect_local_server_type() misidentifying LM Studio as
Ollama (LM Studio returns 200 for /api/tags with an error body).
771 lines
29 KiB
Python
771 lines
29 KiB
Python
"""Model metadata, context lengths, and token estimation utilities.
|
|
|
|
Pure utility functions with no AIAgent dependency. Used by ContextCompressor
|
|
and run_agent.py for pre-flight context checks.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
from urllib.parse import urlparse
|
|
|
|
import requests
|
|
import yaml
|
|
|
|
from hermes_constants import OPENROUTER_MODELS_URL
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
|
|
_model_metadata_cache_time: float = 0
|
|
_MODEL_CACHE_TTL = 3600
|
|
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
|
|
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
|
|
_ENDPOINT_MODEL_CACHE_TTL = 300
|
|
|
|
# Context sizes (tokens) tried for a model whose window is unknown, ordered
# largest first: each context-length error moves to the next smaller entry
# (see get_next_probe_tier).
CONTEXT_PROBE_TIERS = [2_000_000, 1_000_000, 512_000, 200_000, 128_000, 64_000, 32_000]
|
|
|
|
# Fallback context windows (in tokens) keyed by model name.
# get_model_context_length consults this table only after the persistent cache
# and the OpenRouter metadata. It matches by plain substring in either
# direction, trying longer keys first so more specific names win over their
# prefixes. Matching is case-sensitive, which is why both "MiniMax-M2.x" and
# "minimax-m2.x" spellings appear below.
# NOTE(review): values are static snapshots — confirm against provider docs
# when adding or updating entries.
DEFAULT_CONTEXT_LENGTHS = {
    # Provider-prefixed Anthropic IDs
    "anthropic/claude-opus-4": 200000,
    "anthropic/claude-opus-4.5": 200000,
    "anthropic/claude-opus-4.6": 200000,
    "anthropic/claude-sonnet-4": 200000,
    "anthropic/claude-sonnet-4-20250514": 200000,
    "anthropic/claude-sonnet-4.5": 200000,
    "anthropic/claude-sonnet-4.6": 200000,
    "anthropic/claude-haiku-4.5": 200000,
    # Bare Anthropic model IDs (for native API provider)
    "claude-opus-4-6": 200000,
    "claude-sonnet-4-6": 200000,
    "claude-opus-4-5-20251101": 200000,
    "claude-sonnet-4-5-20250929": 200000,
    "claude-opus-4-1-20250805": 200000,
    "claude-opus-4-20250514": 200000,
    "claude-sonnet-4-20250514": 200000,
    "claude-haiku-4-5-20251001": 200000,
    # Provider-prefixed OpenAI / Google / DeepSeek / Meta / Qwen IDs
    "openai/gpt-5": 128000,
    "openai/gpt-4.1": 1047576,
    "openai/gpt-4.1-mini": 1047576,
    "openai/gpt-4o": 128000,
    "openai/gpt-4-turbo": 128000,
    "openai/gpt-4o-mini": 128000,
    "google/gemini-3-pro-preview": 1048576,
    "google/gemini-3-flash": 1048576,
    "google/gemini-2.5-flash": 1048576,
    "google/gemini-2.0-flash": 1048576,
    "google/gemini-2.5-pro": 1048576,
    "deepseek/deepseek-v3.2": 65536,
    "meta-llama/llama-3.3-70b-instruct": 131072,
    "deepseek/deepseek-chat-v3": 65536,
    "qwen/qwen-2.5-72b-instruct": 32768,
    # Bare GLM, Kimi and MiniMax IDs
    "glm-4.7": 202752,
    "glm-5": 202752,
    "glm-4.5": 131072,
    "glm-4.5-flash": 131072,
    "kimi-for-coding": 262144,
    "kimi-k2.5": 262144,
    "kimi-k2-thinking": 262144,
    "kimi-k2-thinking-turbo": 262144,
    "kimi-k2-turbo-preview": 262144,
    "kimi-k2-0905-preview": 131072,
    "MiniMax-M2.7": 204800,
    "MiniMax-M2.7-highspeed": 204800,
    "MiniMax-M2.5": 204800,
    "MiniMax-M2.5-highspeed": 204800,
    "MiniMax-M2.1": 204800,
    # OpenCode Zen models
    "gpt-5.4-pro": 128000,
    "gpt-5.4": 128000,
    "gpt-5.3-codex": 128000,
    "gpt-5.3-codex-spark": 128000,
    "gpt-5.2": 128000,
    "gpt-5.2-codex": 128000,
    "gpt-5.1": 128000,
    "gpt-5.1-codex": 128000,
    "gpt-5.1-codex-max": 128000,
    "gpt-5.1-codex-mini": 128000,
    "gpt-5": 128000,
    "gpt-5-codex": 128000,
    "gpt-5-nano": 128000,
    # Bare model IDs without provider prefix (avoid duplicates with entries above)
    "claude-opus-4-5": 200000,
    "claude-opus-4-1": 200000,
    "claude-sonnet-4-5": 200000,
    "claude-sonnet-4": 200000,
    "claude-haiku-4-5": 200000,
    "claude-3-5-haiku": 200000,
    "gemini-3.1-pro": 1048576,
    "gemini-3-pro": 1048576,
    "gemini-3-flash": 1048576,
    "minimax-m2.5": 204800,
    "minimax-m2.5-free": 204800,
    "minimax-m2.1": 204800,
    "glm-4.6": 202752,
    "kimi-k2": 262144,
    "qwen3-coder": 32768,
    "big-pickle": 128000,
    # Alibaba Cloud / DashScope Qwen models
    "qwen3.5-plus": 131072,
    "qwen3-max": 131072,
    "qwen3-coder-plus": 131072,
    "qwen3-coder-next": 131072,
    "qwen-plus-latest": 131072,
    "qwen3.5-flash": 131072,
    "qwen-vl-max": 32768,
}
|
|
|
|
_CONTEXT_LENGTH_KEYS = (
|
|
"context_length",
|
|
"context_window",
|
|
"max_context_length",
|
|
"max_position_embeddings",
|
|
"max_model_len",
|
|
"max_input_tokens",
|
|
"max_sequence_length",
|
|
"max_seq_len",
|
|
"n_ctx_train",
|
|
"n_ctx",
|
|
)
|
|
|
|
_MAX_COMPLETION_KEYS = (
|
|
"max_completion_tokens",
|
|
"max_output_tokens",
|
|
"max_tokens",
|
|
)
|
|
|
|
# Local server hostnames / address patterns
|
|
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
|
|
|
|
|
|
def _normalize_base_url(base_url: str) -> str:
|
|
return (base_url or "").strip().rstrip("/")
|
|
|
|
|
|
def _is_openrouter_base_url(base_url: str) -> bool:
    """Return True when the normalised *base_url* contains ``openrouter.ai``.

    The test is a case-insensitive substring check on the whole URL.
    """
    haystack = _normalize_base_url(base_url).lower()
    return haystack.find("openrouter.ai") != -1
|
|
|
|
|
|
def _is_custom_endpoint(base_url: str) -> bool:
    """Return True for any non-empty base URL that does not point at OpenRouter."""
    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False
    return not _is_openrouter_base_url(normalized)
|
|
|
|
|
|
def _is_known_provider_base_url(base_url: str) -> bool:
    """Return True when *base_url*'s host looks like a first-party provider API.

    The URL's network location is lower-cased; when there is no netloc, the
    path is used instead (a scheme-less URL gets ``https://`` prepended
    first). That string is checked for any of the fragments below by
    substring, so ``api.minimax`` covers every host containing that
    fragment. Empty URLs never count as known providers.
    """
    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False
    if "://" not in normalized:
        normalized = f"https://{normalized}"
    parsed = urlparse(normalized)
    host = (parsed.netloc or parsed.path).lower()
    provider_hosts = (
        "api.openai.com",
        "chatgpt.com",
        "api.anthropic.com",
        "api.z.ai",
        "api.moonshot.ai",
        "api.kimi.com",
        "api.minimax",
    )
    return any(fragment in host for fragment in provider_hosts)
|
|
|
|
|
|
def is_local_endpoint(base_url: str) -> bool:
    """Return True if *base_url* points at this machine or a private network.

    A URL counts as local when its hostname is one of ``_LOCAL_HOSTS``, or is an
    IP literal that is private (RFC 1918, IPv6 ULA, ...), loopback or
    link-local. WSL's 172.16/12 addresses are covered by the private check.
    DNS names other than those in ``_LOCAL_HOSTS`` are never treated as local.
    This includes names that merely start with private-looking octets, such
    as ``192.168.corp.example``.

    Args:
        base_url: Endpoint URL, with or without a scheme.

    Returns:
        True for local/private endpoints, False otherwise (including empty or
        unparseable URLs).
    """
    import ipaddress

    normalized = _normalize_base_url(base_url)
    if not normalized:
        return False
    url = normalized if "://" in normalized else f"http://{normalized}"
    try:
        # .hostname lower-cases the host and strips IPv6 brackets.
        host = urlparse(url).hostname or ""
    except ValueError:
        return False

    if host in _LOCAL_HOSTS:
        return True

    try:
        addr = ipaddress.ip_address(host)
    except ValueError:
        # Not an IP literal. The previous "looks like 10.x / 192.168.x" string
        # heuristic only ever ran for non-IPs, so it misclassified DNS names.
        return False
    return addr.is_private or addr.is_loopback or addr.is_link_local
|
|
|
|
|
|
def detect_local_server_type(base_url: str) -> Optional[str]:
    """Identify the local inference server at *base_url* by probing known endpoints.

    A trailing ``/v1`` is removed from the normalised URL, then these probes
    run in order (2-second client timeout); the first positive answer wins:

    1. ``GET /api/v1/models`` answers 200 -> ``"lm-studio"``
    2. ``GET /api/tags`` answers 200 with a JSON body containing ``"models"``
       -> ``"ollama"``. LM Studio answers this path with 200 and an error body,
       so the status alone is not enough.
    3. ``GET /props`` answers 200 and its text mentions
       ``default_generation_settings`` -> ``"llamacpp"``
    4. ``GET /version`` answers 200 with a JSON body containing ``"version"``
       -> ``"vllm"``

    Request and parsing errors are swallowed; returns None when nothing matches.
    """
    import httpx

    root = _normalize_base_url(base_url)
    if root.endswith("/v1"):
        root = root[: -len("/v1")]

    def _lm_studio(resp: Any) -> bool:
        return resp.status_code == 200

    def _ollama(resp: Any) -> bool:
        if resp.status_code != 200:
            return False
        try:
            return "models" in resp.json()
        except Exception:
            return False

    def _llamacpp(resp: Any) -> bool:
        return resp.status_code == 200 and "default_generation_settings" in resp.text

    def _vllm(resp: Any) -> bool:
        return resp.status_code == 200 and "version" in resp.json()

    checks = (
        ("/api/v1/models", _lm_studio, "lm-studio"),
        ("/api/tags", _ollama, "ollama"),
        ("/props", _llamacpp, "llamacpp"),
        ("/version", _vllm, "vllm"),
    )

    try:
        with httpx.Client(timeout=2.0) as client:
            for path, recognises, server_type in checks:
                try:
                    if recognises(client.get(f"{root}{path}")):
                        return server_type
                except Exception:
                    continue
    except Exception:
        pass

    return None
|
|
|
|
|
|
def _iter_nested_dicts(value: Any):
|
|
if isinstance(value, dict):
|
|
yield value
|
|
for nested in value.values():
|
|
yield from _iter_nested_dicts(nested)
|
|
elif isinstance(value, list):
|
|
for item in value:
|
|
yield from _iter_nested_dicts(item)
|
|
|
|
|
|
def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
|
|
try:
|
|
if isinstance(value, bool):
|
|
return None
|
|
if isinstance(value, str):
|
|
value = value.strip().replace(",", "")
|
|
result = int(value)
|
|
except (TypeError, ValueError):
|
|
return None
|
|
if minimum <= result <= maximum:
|
|
return result
|
|
return None
|
|
|
|
|
|
def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
    """Return the first plausible integer stored under any of *keys* in *payload*.

    Key names are compared case-insensitively. Dicts are visited parent-first
    via _iter_nested_dicts, and each dict's items in insertion order. The
    payload's own ordering, not the order of *keys*, therefore decides which
    match wins. Candidate values must pass _coerce_reasonable_int.

    Returns:
        The first accepted value, or None when no key yields a usable number.
    """
    wanted = frozenset(name.lower() for name in keys)
    for mapping in _iter_nested_dicts(payload):
        for raw_key, raw_value in mapping.items():
            if str(raw_key).lower() in wanted:
                number = _coerce_reasonable_int(raw_value)
                if number is not None:
                    return number
    return None
|
|
|
|
|
|
def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
    """Return the first context-window value in *payload* (keys: _CONTEXT_LENGTH_KEYS)."""
    found = _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
    return found
|
|
|
|
|
|
def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
    """Return the first output-token cap in *payload* (keys: _MAX_COMPLETION_KEYS)."""
    found = _extract_first_int(payload, _MAX_COMPLETION_KEYS)
    return found
|
|
|
|
|
|
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Return pricing fields from the first nested dict in *payload* that has any.

    Keys are lower-cased and matched against per-field alias lists. The first
    alias in each list with a value other than ``None``/``""`` is stored under
    one of the canonical names ``prompt``, ``completion``, ``request``,
    ``cache_read`` or ``cache_write``. Dicts with only unusable pricing values
    are skipped.

    Returns:
        The pricing mapping, or ``{}`` if no nested dict yields one.
    """
    field_aliases = {
        "prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
        "completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
        "request": ("request", "request_cost"),
        "cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
        "cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
    }
    known_aliases = {alias for group in field_aliases.values() for alias in group}
    missing = object()

    for mapping in _iter_nested_dicts(payload):
        lowered = {str(key).lower(): value for key, value in mapping.items()}
        if known_aliases.isdisjoint(lowered):
            continue
        found: Dict[str, Any] = {}
        for field, aliases in field_aliases.items():
            hit = next(
                (lowered[a] for a in aliases if a in lowered and lowered[a] not in (None, "")),
                missing,
            )
            if hit is not missing:
                found[field] = hit
        if found:
            return found
    return {}
|
|
|
|
|
|
def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
|
|
cache[model_id] = entry
|
|
if "/" in model_id:
|
|
bare_model = model_id.split("/", 1)[1]
|
|
cache.setdefault(bare_model, entry)
|
|
|
|
|
|
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from OpenRouter (cached in memory for one hour).

    Args:
        force_refresh: Ignore the in-memory cache and always query the API.

    Returns:
        Mapping of model id to ``{"context_length", "max_completion_tokens",
        "name", "pricing"}``. The mapping also includes bare-name and
        canonical-slug aliases. On any request or parse failure the previous
        cache (possibly ``{}``) is returned instead of raising.
    """
    global _model_metadata_cache, _model_metadata_cache_time

    cache_is_fresh = (time.time() - _model_metadata_cache_time) < _MODEL_CACHE_TTL
    if not force_refresh and _model_metadata_cache and cache_is_fresh:
        return _model_metadata_cache

    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
        response.raise_for_status()
        data = response.json()

        cache: Dict[str, Dict[str, Any]] = {}
        for model in data.get("data", []):
            # Skip malformed rows rather than letting one bad entry discard
            # the whole catalogue.
            if not isinstance(model, dict):
                continue
            model_id = model.get("id", "")
            if not model_id:
                continue
            # top_provider may be null; treat it like a missing object instead
            # of crashing on .get().
            top_provider = model.get("top_provider")
            if not isinstance(top_provider, dict):
                top_provider = {}
            entry = {
                # `or` also covers an explicit null, which would otherwise be
                # returned to callers of get_model_context_length as None.
                "context_length": model.get("context_length") or 128000,
                "max_completion_tokens": top_provider.get("max_completion_tokens", 4096),
                "name": model.get("name", model_id),
                "pricing": model.get("pricing", {}),
            }
            _add_model_aliases(cache, model_id, entry)
            canonical = model.get("canonical_slug", "")
            if canonical and canonical != model_id:
                _add_model_aliases(cache, canonical, entry)

        _model_metadata_cache = cache
        _model_metadata_cache_time = time.time()
        logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
        return cache

    except Exception as e:
        # Use the module logger (not the root logger) with lazy formatting.
        logger.warning("Failed to fetch model metadata from OpenRouter: %s", e)
        return _model_metadata_cache or {}
|
|
|
|
|
|
def fetch_endpoint_model_metadata(
    base_url: str,
    api_key: str = "",
    force_refresh: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.

    This is used for explicit custom endpoints where hardcoded global
    model-name defaults are unreliable. Both ``<base>/models`` and the
    variant with ``/v1`` toggled are tried. Results, including an empty
    result on failure, are cached in memory per normalised base URL for
    ``_ENDPOINT_MODEL_CACHE_TTL`` seconds.

    For llama.cpp servers (any model ``owned_by == "llamacpp"``), the
    allocated context ``default_generation_settings.n_ctx`` from ``/props``
    overrides the listed value for the model named by ``model_alias``.

    Args:
        base_url: Endpoint base URL. Empty and OpenRouter URLs return ``{}``.
        api_key: Optional bearer token.
        force_refresh: Bypass the in-memory cache.

    Returns:
        Mapping of model id (plus bare-name aliases) to a dict with ``name``
        and, when reported, ``context_length``, ``max_completion_tokens`` and
        ``pricing``.
    """
    normalized = _normalize_base_url(base_url)
    if not normalized or _is_openrouter_base_url(normalized):
        return {}

    if not force_refresh:
        cached = _endpoint_model_metadata_cache.get(normalized)
        cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
        if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
            return cached

    candidates = [normalized]
    if normalized.endswith("/v1"):
        alternate = normalized[:-3].rstrip("/")
    else:
        alternate = normalized + "/v1"
    if alternate and alternate not in candidates:
        candidates.append(alternate)

    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    last_error: Optional[Exception] = None

    for candidate in candidates:
        url = candidate.rstrip("/") + "/models"
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            payload = response.json()
            models = payload.get("data", [])

            cache: Dict[str, Dict[str, Any]] = {}
            for model in models:
                if not isinstance(model, dict):
                    continue
                model_id = model.get("id")
                if not model_id:
                    continue
                entry: Dict[str, Any] = {"name": model.get("name", model_id)}
                context_length = _extract_context_length(model)
                if context_length is not None:
                    entry["context_length"] = context_length
                max_completion_tokens = _extract_max_completion_tokens(model)
                if max_completion_tokens is not None:
                    entry["max_completion_tokens"] = max_completion_tokens
                pricing = _extract_pricing(model)
                if pricing:
                    entry["pricing"] = pricing
                _add_model_aliases(cache, model_id, entry)

            # If this is a llama.cpp server, query /props for the actual allocated context.
            is_llamacpp = any(
                isinstance(m, dict) and m.get("owned_by") == "llamacpp" for m in models
            )
            if is_llamacpp:
                # Strip only a trailing /v1. The old str.replace("/v1", "") also
                # removed "/v1" from the middle of the path.
                server_root = candidate.rstrip("/")
                if server_root.endswith("/v1"):
                    server_root = server_root[:-3]
                try:
                    props_resp = requests.get(server_root + "/props", headers=headers, timeout=5)
                    if props_resp.ok:
                        props = props_resp.json()
                        gen_settings = props.get("default_generation_settings") or {}
                        # Validate before trusting it; small runtime contexts are legitimate.
                        n_ctx = _coerce_reasonable_int(gen_settings.get("n_ctx"), minimum=1)
                        model_alias = props.get("model_alias", "")
                        if n_ctx and model_alias and model_alias in cache:
                            cache[model_alias]["context_length"] = n_ctx
                except Exception as exc:
                    logger.debug("Failed to read llama.cpp /props from %s: %s", server_root, exc)

            _endpoint_model_metadata_cache[normalized] = cache
            _endpoint_model_metadata_cache_time[normalized] = time.time()
            return cache
        except Exception as exc:
            last_error = exc

    if last_error:
        logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
    # Negative-cache the failure so repeated lookups don't re-hit a dead endpoint.
    _endpoint_model_metadata_cache[normalized] = {}
    _endpoint_model_metadata_cache_time[normalized] = time.time()
    return {}
|
|
|
|
|
|
def _get_context_cache_path() -> Path:
|
|
"""Return path to the persistent context length cache file."""
|
|
hermes_home = Path(os.environ.get("HERMES_HOME", Path.home() / ".hermes"))
|
|
return hermes_home / "context_length_cache.yaml"
|
|
|
|
|
|
def _load_context_cache() -> Dict[str, int]:
    """Load the ``model@base_url`` -> context-length cache from disk.

    Returns:
        The ``context_lengths`` mapping from the YAML cache file, with only
        positive integer values kept. Returns ``{}`` when the file is missing,
        unreadable, or not shaped like ``{"context_lengths": {...}}``.
    """
    path = _get_context_cache_path()
    if not path.exists():
        return {}
    try:
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f) or {}
    except Exception as e:
        logger.debug("Failed to load context length cache: %s", e)
        return {}

    # A hand-edited file may contain a bare "context_lengths:" (-> None) or a
    # non-mapping. Returning that crashed .get()/item assignment in callers.
    lengths = data.get("context_lengths") if isinstance(data, dict) else None
    if not isinstance(lengths, dict):
        return {}
    # Drop values that are not usable token counts so callers never receive
    # strings, booleans or None as a context length.
    return {
        str(key): value
        for key, value in lengths.items()
        if isinstance(value, int) and not isinstance(value, bool) and value > 0
    }
|
|
|
|
|
|
def save_context_length(model: str, base_url: str, length: int) -> None:
    """Persist a discovered context length for a model+provider combo.

    Cache key is ``model@base_url`` so the same model name served from
    different providers can have different limits. The file is written to a
    temporary sibling and then moved into place with ``os.replace``. A crash
    mid-write therefore cannot leave a truncated cache behind. Failures are
    logged at debug level and never raised.
    """
    key = f"{model}@{base_url}"
    cache = _load_context_cache()
    if not isinstance(cache, dict):
        # Defensive: never let a malformed cache break discovery.
        cache = {}
    if cache.get(key) == length:
        return  # already stored
    cache[key] = length

    path = _get_context_cache_path()
    # Per-process temp name so concurrent writers don't share a temp file.
    tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
    try:
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(tmp_path, "w", encoding="utf-8") as f:
            yaml.dump({"context_lengths": cache}, f, default_flow_style=False)
        os.replace(tmp_path, path)
        logger.info("Cached context length %s -> %s tokens", key, f"{length:,}")
    except Exception as e:
        logger.debug("Failed to save context length cache: %s", e)
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError:
            pass
|
|
|
|
|
|
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
    """Return the persisted context length for *model* served at *base_url*, if any.

    Looks up the ``model@base_url`` key in the on-disk cache; None when absent.
    """
    return _load_context_cache().get(f"{model}@{base_url}")
|
|
|
|
|
|
def get_next_probe_tier(current_length: int) -> Optional[int]:
    """Return the first entry of CONTEXT_PROBE_TIERS below *current_length*.

    Because the tiers are listed largest first, this is the next step down.
    Returns None once *current_length* is at or below the smallest tier.
    """
    smaller_tiers = (tier for tier in CONTEXT_PROBE_TIERS if tier < current_length)
    return next(smaller_tiers, None)
|
|
|
|
|
|
def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
    """Try to extract the actual context limit from an API error message.

    Many providers include the limit in their error text, e.g.:
    - "maximum context length is 32768 tokens"
    - "context_length_exceeded: 131072"
    - "Maximum context size 32768 exceeded"
    - "model's max context length is 65536"
    - "prompt is too long: 250000 tokens > 200000 maximum"
    - "request (5000 tokens) exceeds the available context size (4096 tokens)"

    Args:
        error_msg: Error text (anything ``str()``-able; None/empty allowed).

    Returns:
        The first number matched by the patterns below that lies in
        1024..10,000,000, or None.
    """
    if not error_msg:
        return None
    error_lower = str(error_msg).lower()
    # Patterns are tried in order. `[\s_]*` lets snake_case identifiers such as
    # "context_length_exceeded" / "max_context_length" match; plain `\s*` could
    # not cross the underscore, so the documented example never parsed.
    # `\(?` allows "context size (4096 tokens)" (llama.cpp style).
    patterns = (
        r'(?:max(?:imum)?|limit)[\s_]*(?:context[\s_]*)?(?:length|size|window)?\s*(?:is|of|:)?\s*\(?(\d{4,})',
        r'context[\s_]*(?:length|size|window)[\s_]*(?:exceeded)?\s*(?:is|of|:)?\s*\(?(\d{4,})',
        r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
        r'>\s*(\d{4,})\s*(?:max|limit|token)',  # "250000 tokens > 200000 maximum"
        r'(\d{4,})\s*(?:max(?:imum)?)\b',  # "200000 maximum"
    )
    for pattern in patterns:
        match = re.search(pattern, error_lower)
        if match:
            limit = int(match.group(1))
            # Sanity check: must be a reasonable context length
            if 1024 <= limit <= 10_000_000:
                return limit
    return None
|
|
|
|
|
|
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
|
"""Return True if *candidate_id* (from server) matches *lookup_model* (configured).
|
|
|
|
Supports two forms:
|
|
- Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
|
|
- Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches "nvidia-nemotron-super-49b-v1"
|
|
(the part after the last "/" equals lookup_model)
|
|
|
|
This covers LM Studio's native API which stores models as "publisher/slug"
|
|
while users typically configure only the slug after the "local:" prefix.
|
|
"""
|
|
if candidate_id == lookup_model:
|
|
return True
|
|
# Slug match: basename of candidate equals the lookup name
|
|
if "/" in candidate_id and candidate_id.rsplit("/", 1)[1] == lookup_model:
|
|
return True
|
|
return False
|
|
|
|
|
|
def _ollama_context_length(client: Any, server_url: str, model: str) -> Optional[int]:
    """Return the context window Ollama reports for *model* via ``POST /api/show``.

    A ``num_ctx`` line in ``parameters`` (the runtime setting) wins over
    ``model_info``'s ``*context_length`` (the model's trained maximum). This
    mirrors the LM Studio preference for the loaded value over the maximum.
    """
    # Current Ollama docs name the field "model"; "name" is what was sent
    # before and is kept for servers that only read it.
    resp = client.post(f"{server_url}/api/show", json={"model": model, "name": model})
    if resp.status_code != 200:
        return None
    data = resp.json()
    if not isinstance(data, dict):
        return None

    params = data.get("parameters") or ""
    if isinstance(params, str):
        for line in params.splitlines():
            parts = line.split()
            if len(parts) >= 2 and parts[0] == "num_ctx":
                ctx = _coerce_reasonable_int(parts[-1], minimum=1)
                if ctx:
                    return ctx

    model_info = data.get("model_info") or {}
    if isinstance(model_info, dict):
        for key, value in model_info.items():
            if "context_length" in str(key) and isinstance(value, (int, float)):
                ctx = _coerce_reasonable_int(value, minimum=1)
                if ctx:
                    return ctx
    return None


def _lmstudio_context_length(client: Any, server_url: str, model: str) -> Optional[int]:
    """Return *model*'s context length from LM Studio's native ``GET /api/v1/models``.

    A loaded instance's configured ``context_length`` (runtime value) is
    preferred over ``max_context_length`` (theoretical maximum). Entries are
    matched on ``key`` or ``id`` via _model_id_matches, so a bare slug finds
    "publisher/slug".
    """
    resp = client.get(f"{server_url}/api/v1/models")
    if resp.status_code != 200:
        return None
    data = resp.json()
    entries = data.get("models") if isinstance(data, dict) else None
    for entry in entries or []:
        if not isinstance(entry, dict):
            continue
        key = str(entry.get("key") or "")
        entry_id = str(entry.get("id") or "")
        if not (_model_id_matches(key, model) or _model_id_matches(entry_id, model)):
            continue
        for instance in entry.get("loaded_instances") or []:
            config = instance.get("config") if isinstance(instance, dict) else None
            if isinstance(config, dict):
                ctx = _coerce_reasonable_int(config.get("context_length"), minimum=1)
                if ctx:
                    return ctx
        ctx = _coerce_reasonable_int(
            entry.get("max_context_length") or entry.get("context_length"), minimum=1
        )
        if ctx:
            return ctx
    return None


def _openai_compat_context_length(client: Any, server_url: str, model: str) -> Optional[int]:
    """Return *model*'s context length from OpenAI-style model endpoints.

    Tries ``GET /v1/models/{model}`` first, then scans ``GET /v1/models``. In
    each entry it reads ``max_model_len`` (vLLM), then ``context_length``,
    then ``max_tokens``.
    """
    def _from_entry(entry: Any) -> Optional[int]:
        if not isinstance(entry, dict):
            return None
        raw = entry.get("max_model_len") or entry.get("context_length") or entry.get("max_tokens")
        return _coerce_reasonable_int(raw, minimum=1)

    resp = client.get(f"{server_url}/v1/models/{model}")
    if resp.status_code == 200:
        try:
            ctx = _from_entry(resp.json())
        except ValueError:
            ctx = None
        if ctx:
            return ctx

    resp = client.get(f"{server_url}/v1/models")
    if resp.status_code != 200:
        return None
    data = resp.json()
    listed = data.get("data") if isinstance(data, dict) else None
    for entry in listed or []:
        if isinstance(entry, dict) and _model_id_matches(str(entry.get("id") or ""), model):
            ctx = _from_entry(entry)
            if ctx:
                return ctx
    return None


def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
    """Query a local server (Ollama, LM Studio, llama.cpp, vLLM) for *model*'s context length.

    The configured name is tried first. If it looks like ``provider:model``
    (e.g. ``local:qwen3``), the part after the first colon is tried as a
    fallback. Trying the full name first keeps Ollama ``name:tag`` ids such as
    ``llama3.1:8b`` intact; stripping unconditionally turned them into ``8b``.
    Server-specific APIs are probed first, then the generic OpenAI-style
    endpoints. Each probe is isolated, so one failing API no longer skips the
    rest.

    Returns:
        A positive token count, or None if no server reported one.
    """
    try:
        import httpx
    except ImportError:
        logger.debug("httpx unavailable; skipping local context-length query")
        return None

    names = [model]
    if ":" in model and not model.startswith("http"):
        prefix, _, rest = model.partition(":")
        # A "/" before the colon means the colon is a tag/variant suffix
        # (e.g. "hf.co/org/repo:Q4_K_M"), not a provider prefix.
        if rest and "/" not in prefix and rest not in names:
            names.append(rest)

    # Strip /v1 suffix to get the server root
    server_url = base_url.rstrip("/")
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]

    try:
        server_type = detect_local_server_type(base_url)
    except Exception:
        server_type = None

    probes = []
    if server_type == "ollama":
        probes.append(_ollama_context_length)
    elif server_type == "lm-studio":
        probes.append(_lmstudio_context_length)
    probes.append(_openai_compat_context_length)

    try:
        with httpx.Client(timeout=3.0) as client:
            for probe in probes:
                for name in names:
                    try:
                        ctx = probe(client, server_url, name)
                    except Exception as exc:
                        logger.debug("%s failed for %r at %s: %s", probe.__name__, name, server_url, exc)
                        continue
                    if ctx:
                        return ctx
    except Exception as exc:
        logger.debug("Local context-length query failed for %s: %s", server_url, exc)

    return None
|
|
|
|
|
|
def get_model_context_length(
|
|
model: str,
|
|
base_url: str = "",
|
|
api_key: str = "",
|
|
config_context_length: int | None = None,
|
|
) -> int:
|
|
"""Get the context length for a model.
|
|
|
|
Resolution order:
|
|
0. Explicit config override (model.context_length in config.yaml)
|
|
1. Persistent cache (previously discovered via probing)
|
|
2. Active endpoint metadata (/models for explicit custom endpoints)
|
|
3. Local server query (for local endpoints when model not in /models list)
|
|
4. OpenRouter API metadata
|
|
5. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
|
|
6. First probe tier (2M) — will be narrowed on first context error
|
|
"""
|
|
# 0. Explicit config override — user knows best
|
|
if config_context_length is not None and isinstance(config_context_length, int) and config_context_length > 0:
|
|
return config_context_length
|
|
|
|
# Normalise provider-prefixed model names (e.g. "local:model-name" →
|
|
# "model-name") so cache lookups and server queries use the bare ID that
|
|
# local servers actually know about.
|
|
if ":" in model and not model.startswith("http"):
|
|
model = model.split(":", 1)[1]
|
|
|
|
# 1. Check persistent cache (model+provider)
|
|
if base_url:
|
|
cached = get_cached_context_length(model, base_url)
|
|
if cached is not None:
|
|
return cached
|
|
|
|
# 2. Active endpoint metadata for explicit custom routes
|
|
if _is_custom_endpoint(base_url):
|
|
endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
|
|
matched = endpoint_metadata.get(model)
|
|
if not matched:
|
|
# Single-model servers: if only one model is loaded, use it
|
|
if len(endpoint_metadata) == 1:
|
|
matched = next(iter(endpoint_metadata.values()))
|
|
else:
|
|
# Fuzzy match: substring in either direction
|
|
for key, entry in endpoint_metadata.items():
|
|
if model in key or key in model:
|
|
matched = entry
|
|
break
|
|
if matched:
|
|
context_length = matched.get("context_length")
|
|
if isinstance(context_length, int):
|
|
return context_length
|
|
if not _is_known_provider_base_url(base_url):
|
|
# Explicit third-party endpoints should not borrow fuzzy global
|
|
# defaults from unrelated providers with similarly named models.
|
|
# But first try querying the local server directly.
|
|
if is_local_endpoint(base_url):
|
|
local_ctx = _query_local_context_length(model, base_url)
|
|
if local_ctx and local_ctx > 0:
|
|
save_context_length(model, base_url, local_ctx)
|
|
return local_ctx
|
|
logger.info(
|
|
"Could not detect context length for model %r at %s — "
|
|
"defaulting to %s tokens (probe-down). Set model.context_length "
|
|
"in config.yaml to override.",
|
|
model, base_url, f"{CONTEXT_PROBE_TIERS[0]:,}",
|
|
)
|
|
return CONTEXT_PROBE_TIERS[0]
|
|
|
|
# 3. OpenRouter API metadata
|
|
metadata = fetch_model_metadata()
|
|
if model in metadata:
|
|
return metadata[model].get("context_length", 128000)
|
|
|
|
# 4. Hardcoded defaults (fuzzy match — longest key first for specificity)
|
|
for default_model, length in sorted(
|
|
DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
|
|
):
|
|
if default_model in model or model in default_model:
|
|
return length
|
|
|
|
# 5. Query local server for unknown models before defaulting to 2M
|
|
if base_url and is_local_endpoint(base_url):
|
|
local_ctx = _query_local_context_length(model, base_url)
|
|
if local_ctx and local_ctx > 0:
|
|
save_context_length(model, base_url, local_ctx)
|
|
return local_ctx
|
|
|
|
# 6. Unknown model — start at highest probe tier
|
|
return CONTEXT_PROBE_TIERS[0]
|
|
|
|
|
|
def estimate_tokens_rough(text: str) -> int:
    """Return a crude token count for *text* at roughly four characters per token.

    Meant for pre-flight checks only; empty or ``None`` text counts as zero.
    """
    return len(text) // 4 if text else 0
|
|
|
|
|
|
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
    """Return a crude token count for a message list (pre-flight only).

    Each message is rendered with ``str()``; the summed character count is
    divided by four.
    """
    char_count = 0
    for message in messages:
        char_count += len(str(message))
    return char_count // 4
|