Files
hermes-agent/agent/model_metadata.py
Peppi Littera ec5fdb8b92 feat: query local servers for actual context window size
Custom endpoints (LM Studio, Ollama, vLLM, llama.cpp) silently fall
back to 2M tokens when /v1/models doesn't include context_length.

Adds _query_local_context_length() which queries server-specific APIs:
- LM Studio: /api/v1/models (max_context_length + loaded instances)
- Ollama: /api/show (model_info + num_ctx parameters)
- llama.cpp: /props (n_ctx from default_generation_settings)
- vLLM: /v1/models/{model} (max_model_len)

Prefers loaded instance context over max (e.g., 122K loaded vs 1M max).
Results are cached via save_context_length() to avoid repeated queries.

Also fixes detect_local_server_type() misidentifying LM Studio as
Ollama (LM Studio returns 200 for /api/tags with an error body).
2026-03-19 21:32:04 +01:00

771 lines
29 KiB
Python

"""Model metadata, context lengths, and token estimation utilities.
Pure utility functions with no AIAgent dependency. Used by ContextCompressor
and run_agent.py for pre-flight context checks.
"""
import logging
import os
import re
import time
from pathlib import Path
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import requests
import yaml
from hermes_constants import OPENROUTER_MODELS_URL
logger = logging.getLogger(__name__)
# In-memory cache of OpenRouter model metadata keyed by model ID, plus the
# time.time() stamp of the last successful fetch. fetch_model_metadata()
# reuses the cache while it is younger than _MODEL_CACHE_TTL seconds.
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}
_model_metadata_cache_time: float = 0
_MODEL_CACHE_TTL = 3600
# Per-endpoint equivalent used by fetch_endpoint_model_metadata():
# normalized base URL -> {model ID -> metadata}, with one timestamp per URL
# and a TTL in seconds.
_endpoint_model_metadata_cache: Dict[str, Dict[str, Dict[str, Any]]] = {}
_endpoint_model_metadata_cache_time: Dict[str, float] = {}
_ENDPOINT_MODEL_CACHE_TTL = 300
# Descending tiers for context length probing when the model is unknown.
# We start high and step down on context-length errors until one works.
# The list must stay sorted from largest to smallest: get_next_probe_tier()
# returns the first entry that is below the current length, and
# get_model_context_length() uses element 0 as the "unknown model" default.
CONTEXT_PROBE_TIERS = [
2_000_000,
1_000_000,
512_000,
200_000,
128_000,
64_000,
32_000,
]
# Hardcoded fallback context window sizes (in tokens) keyed by model ID.
# get_model_context_length() consults this table only after the persistent
# cache, endpoint metadata and OpenRouter metadata have failed. Keys are
# matched by substring in either direction, longest key first, so a more
# specific key wins over a shorter one that it contains.
DEFAULT_CONTEXT_LENGTHS = {
"anthropic/claude-opus-4": 200000,
"anthropic/claude-opus-4.5": 200000,
"anthropic/claude-opus-4.6": 200000,
"anthropic/claude-sonnet-4": 200000,
"anthropic/claude-sonnet-4-20250514": 200000,
"anthropic/claude-sonnet-4.5": 200000,
"anthropic/claude-sonnet-4.6": 200000,
"anthropic/claude-haiku-4.5": 200000,
# Bare Anthropic model IDs (for native API provider)
"claude-opus-4-6": 200000,
"claude-sonnet-4-6": 200000,
"claude-opus-4-5-20251101": 200000,
"claude-sonnet-4-5-20250929": 200000,
"claude-opus-4-1-20250805": 200000,
"claude-opus-4-20250514": 200000,
"claude-sonnet-4-20250514": 200000,
"claude-haiku-4-5-20251001": 200000,
"openai/gpt-5": 128000,
"openai/gpt-4.1": 1047576,
"openai/gpt-4.1-mini": 1047576,
"openai/gpt-4o": 128000,
"openai/gpt-4-turbo": 128000,
"openai/gpt-4o-mini": 128000,
"google/gemini-3-pro-preview": 1048576,
"google/gemini-3-flash": 1048576,
"google/gemini-2.5-flash": 1048576,
"google/gemini-2.0-flash": 1048576,
"google/gemini-2.5-pro": 1048576,
"deepseek/deepseek-v3.2": 65536,
"meta-llama/llama-3.3-70b-instruct": 131072,
"deepseek/deepseek-chat-v3": 65536,
"qwen/qwen-2.5-72b-instruct": 32768,
"glm-4.7": 202752,
"glm-5": 202752,
"glm-4.5": 131072,
"glm-4.5-flash": 131072,
"kimi-for-coding": 262144,
"kimi-k2.5": 262144,
"kimi-k2-thinking": 262144,
"kimi-k2-thinking-turbo": 262144,
"kimi-k2-turbo-preview": 262144,
"kimi-k2-0905-preview": 131072,
"MiniMax-M2.7": 204800,
"MiniMax-M2.7-highspeed": 204800,
"MiniMax-M2.5": 204800,
"MiniMax-M2.5-highspeed": 204800,
"MiniMax-M2.1": 204800,
# OpenCode Zen models
"gpt-5.4-pro": 128000,
"gpt-5.4": 128000,
"gpt-5.3-codex": 128000,
"gpt-5.3-codex-spark": 128000,
"gpt-5.2": 128000,
"gpt-5.2-codex": 128000,
"gpt-5.1": 128000,
"gpt-5.1-codex": 128000,
"gpt-5.1-codex-max": 128000,
"gpt-5.1-codex-mini": 128000,
"gpt-5": 128000,
"gpt-5-codex": 128000,
"gpt-5-nano": 128000,
# Bare model IDs without provider prefix (avoid duplicates with entries above)
"claude-opus-4-5": 200000,
"claude-opus-4-1": 200000,
"claude-sonnet-4-5": 200000,
"claude-sonnet-4": 200000,
"claude-haiku-4-5": 200000,
"claude-3-5-haiku": 200000,
"gemini-3.1-pro": 1048576,
"gemini-3-pro": 1048576,
"gemini-3-flash": 1048576,
"minimax-m2.5": 204800,
"minimax-m2.5-free": 204800,
"minimax-m2.1": 204800,
"glm-4.6": 202752,
"kimi-k2": 262144,
"qwen3-coder": 32768,
"big-pickle": 128000,
# Alibaba Cloud / DashScope Qwen models
"qwen3.5-plus": 131072,
"qwen3-max": 131072,
"qwen3-coder-plus": 131072,
"qwen3-coder-next": 131072,
"qwen-plus-latest": 131072,
"qwen3.5-flash": 131072,
"qwen-vl-max": 32768,
}
# Field names that may carry a model's context window size in endpoint
# metadata. _extract_first_int() compares them case-insensitively against keys
# at any nesting depth and returns the first usable value in payload traversal
# order, so the order of this tuple does not express priority.
_CONTEXT_LENGTH_KEYS = (
"context_length",
"context_window",
"max_context_length",
"max_position_embeddings",
"max_model_len",
"max_input_tokens",
"max_sequence_length",
"max_seq_len",
"n_ctx_train",
"n_ctx",
)
# Field names that may carry a model's maximum completion/output token count;
# looked up the same way as _CONTEXT_LENGTH_KEYS.
_MAX_COMPLETION_KEYS = (
"max_completion_tokens",
"max_output_tokens",
"max_tokens",
)
# Hostnames that is_local_endpoint() accepts as local without any further
# address-range checks.
_LOCAL_HOSTS = ("localhost", "127.0.0.1", "::1", "0.0.0.0")
def _normalize_base_url(base_url: str) -> str:
    """Return *base_url* without surrounding whitespace or trailing slashes.

    A falsy value (``None`` or ``""``) yields the empty string.
    """
    if not base_url:
        return ""
    trimmed = base_url.strip()
    return trimmed.rstrip("/")
def _is_openrouter_base_url(base_url: str) -> bool:
    """Return True when the normalized URL contains ``openrouter.ai`` (case-insensitive)."""
    lowered = _normalize_base_url(base_url).lower()
    return "openrouter.ai" in lowered
def _is_custom_endpoint(base_url: str) -> bool:
    """Return True for any non-empty base URL that is not an OpenRouter URL."""
    cleaned = _normalize_base_url(base_url)
    if not cleaned:
        return False
    return not _is_openrouter_base_url(cleaned)
def _is_known_provider_base_url(base_url: str) -> bool:
    """Return True when the URL's host contains a known hosted-provider name.

    A scheme-less URL is parsed as if it started with ``https://``. Matching is
    a case-insensitive substring test against the host part.
    """
    cleaned = _normalize_base_url(base_url)
    if not cleaned:
        return False
    if "://" not in cleaned:
        cleaned = f"https://{cleaned}"
    pieces = urlparse(cleaned)
    # Fall back to the path when the URL yields no network location.
    host = (pieces.netloc or pieces.path).lower()
    provider_fragments = (
        "api.openai.com",
        "chatgpt.com",
        "api.anthropic.com",
        "api.z.ai",
        "api.moonshot.ai",
        "api.kimi.com",
        "api.minimax",
    )
    for fragment in provider_fragments:
        if fragment in host:
            return True
    return False
def is_local_endpoint(base_url: str) -> bool:
    """Return True if base_url points to a local machine (localhost / RFC-1918 / WSL).

    The host is accepted when it is one of ``_LOCAL_HOSTS``, when it parses as
    a private, loopback or link-local IP address, or when it is a four-part
    dotted name whose first two parts fall in the 10.x, 172.16-31.x or
    192.168.x ranges.
    """
    import ipaddress

    cleaned = _normalize_base_url(base_url)
    if not cleaned:
        return False
    if "://" not in cleaned:
        cleaned = f"http://{cleaned}"
    try:
        host = urlparse(cleaned).hostname or ""
    except Exception:
        return False

    if host in _LOCAL_HOSTS:
        return True

    # A well-formed IP literal is classified by the ipaddress module.
    try:
        ip = ipaddress.ip_address(host)
    except ValueError:
        ip = None
    if ip is not None:
        return ip.is_private or ip.is_loopback or ip.is_link_local

    # Otherwise judge a four-part dotted host by its first two numeric parts
    # (e.g. 172.26.x.x for WSL).
    octets = host.split(".")
    if len(octets) != 4:
        return False
    try:
        first, second = int(octets[0]), int(octets[1])
    except ValueError:
        return False
    return (
        first == 10
        or (first == 172 and 16 <= second <= 31)
        or (first == 192 and second == 168)
    )
def detect_local_server_type(base_url: str) -> Optional[str]:
    """Identify the local inference server behind *base_url* by probing endpoints.

    Probes run in a fixed order and the first one that is accepted decides the
    result. A probe that raises (connection error, bad JSON, ...) is skipped.

    Returns one of: "ollama", "lm-studio", "vllm", "llamacpp", or None.
    """
    import httpx

    root = _normalize_base_url(base_url)
    if root.endswith("/v1"):
        root = root[:-3]

    def _any_ok(_reply) -> bool:
        return True

    def _json_has_models(reply) -> bool:
        return "models" in reply.json()

    def _text_has_generation_settings(reply) -> bool:
        return "default_generation_settings" in reply.text

    def _json_has_version(reply) -> bool:
        return "version" in reply.json()

    # (path, server label, acceptance check applied to a 200 response)
    probes = (
        # LM Studio exposes /api/v1/models — checked first (most specific).
        ("/api/v1/models", "lm-studio", _any_ok),
        # Ollama answers /api/tags with {"models": [...]}. LM Studio returns
        # {"error": "Unexpected endpoint"} with status 200 on this path, so the
        # body must actually contain "models".
        ("/api/tags", "ollama", _json_has_models),
        # llama.cpp exposes /props.
        ("/props", "llamacpp", _text_has_generation_settings),
        # vLLM exposes /version.
        ("/version", "vllm", _json_has_version),
    )

    try:
        with httpx.Client(timeout=2.0) as client:
            for path, label, accepted in probes:
                try:
                    reply = client.get(f"{root}{path}")
                    if reply.status_code == 200 and accepted(reply):
                        return label
                except Exception:
                    continue
    except Exception:
        pass
    return None
def _iter_nested_dicts(value: Any):
    """Yield every dict reachable from *value* through dict values and lists.

    Traversal is depth-first and pre-order: a dict is yielded before anything
    nested inside it, and siblings keep their original order. Values that are
    neither dicts nor lists are ignored.
    """
    pending = [value]
    while pending:
        current = pending.pop()
        if isinstance(current, dict):
            yield current
            # Push in reverse so the first child is popped (visited) first.
            pending.extend(reversed(list(current.values())))
        elif isinstance(current, list):
            pending.extend(reversed(current))
def _coerce_reasonable_int(value: Any, minimum: int = 1024, maximum: int = 10_000_000) -> Optional[int]:
    """Convert *value* to an int if it lies within ``[minimum, maximum]``.

    Args:
        value: Candidate value. Strings are stripped and thousands separators
            (",") removed before conversion. Booleans are rejected even though
            they are ints.
        minimum: Smallest accepted result (inclusive).
        maximum: Largest accepted result (inclusive).

    Returns:
        The converted integer, or None when the value cannot be converted or
        falls outside the accepted range.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, str):
        value = value.strip().replace(",", "")
    try:
        result = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int(float("inf")), e.g. from a JSON value like 1e999.
        return None
    if minimum <= result <= maximum:
        return result
    return None
def _extract_first_int(payload: Dict[str, Any], keys: tuple[str, ...]) -> Optional[int]:
    """Return the first plausible integer stored under any of *keys* in *payload*.

    Keys are compared case-insensitively at every nesting level. A value counts
    only if ``_coerce_reasonable_int`` accepts it; the first accepted value in
    traversal order is returned, or None when there is none.
    """
    wanted = {name.lower() for name in keys}
    for mapping in _iter_nested_dicts(payload):
        for raw_key, raw_value in mapping.items():
            if str(raw_key).lower() in wanted:
                number = _coerce_reasonable_int(raw_value)
                if number is not None:
                    return number
    return None
def _extract_context_length(payload: Dict[str, Any]) -> Optional[int]:
    """Return the context window size found in *payload*, or None."""
    context_length = _extract_first_int(payload, _CONTEXT_LENGTH_KEYS)
    return context_length
def _extract_max_completion_tokens(payload: Dict[str, Any]) -> Optional[int]:
    """Return the maximum completion token count found in *payload*, or None."""
    max_completion = _extract_first_int(payload, _MAX_COMPLETION_KEYS)
    return max_completion
def _extract_pricing(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Collect pricing fields from the first nested dict that carries any.

    Each nested dict is inspected with lower-cased keys. For every canonical
    field the first alias with a value other than None or "" is taken. The
    first dict that yields at least one field determines the result; an empty
    dict is returned when none does.
    """
    alias_map = {
        "prompt": ("prompt", "input", "input_cost_per_token", "prompt_token_cost"),
        "completion": ("completion", "output", "output_cost_per_token", "completion_token_cost"),
        "request": ("request", "request_cost"),
        "cache_read": ("cache_read", "cached_prompt", "input_cache_read", "cache_read_cost_per_token"),
        "cache_write": ("cache_write", "cache_creation", "input_cache_write", "cache_write_cost_per_token"),
    }
    for mapping in _iter_nested_dicts(payload):
        lowered = {str(name).lower(): item for name, item in mapping.items()}
        found: Dict[str, Any] = {}
        for target, aliases in alias_map.items():
            for alias in aliases:
                if alias in lowered and lowered[alias] not in (None, ""):
                    found[target] = lowered[alias]
                    break
        # A dict without any alias simply leaves `found` empty and is skipped.
        if found:
            return found
    return {}
def _add_model_aliases(cache: Dict[str, Dict[str, Any]], model_id: str, entry: Dict[str, Any]) -> None:
    """Store *entry* under *model_id* and, if absent, under its un-prefixed name.

    The un-prefixed name is everything after the first "/". An existing entry
    under that bare name is never overwritten.
    """
    cache[model_id] = entry
    _provider, slash, bare_name = model_id.partition("/")
    if slash:
        cache.setdefault(bare_name, entry)
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from OpenRouter (cached for 1 hour).

    Args:
        force_refresh: When True, bypass the in-memory cache and refetch.

    Returns:
        Mapping of model ID (and bare/canonical aliases) to a dict with
        ``context_length``, ``max_completion_tokens``, ``name`` and ``pricing``.
        On failure the previously cached mapping is returned, or ``{}`` if
        nothing has been cached yet.
    """
    global _model_metadata_cache, _model_metadata_cache_time
    cache_age = time.time() - _model_metadata_cache_time
    if not force_refresh and _model_metadata_cache and cache_age < _MODEL_CACHE_TTL:
        return _model_metadata_cache
    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
        response.raise_for_status()
        data = response.json()
        cache: Dict[str, Dict[str, Any]] = {}
        for model in data.get("data", []):
            if not isinstance(model, dict):
                # One malformed entry must not discard the whole listing.
                continue
            model_id = model.get("id", "")
            # "top_provider" may be present but null; treat that like missing.
            top_provider = model.get("top_provider") or {}
            entry = {
                # A null context_length gets the same default as a missing one
                # so callers never receive None for it.
                "context_length": model.get("context_length") or 128000,
                "max_completion_tokens": top_provider.get("max_completion_tokens", 4096),
                "name": model.get("name", model_id),
                "pricing": model.get("pricing", {}),
            }
            _add_model_aliases(cache, model_id, entry)
            canonical = model.get("canonical_slug", "")
            if canonical and canonical != model_id:
                _add_model_aliases(cache, canonical, entry)
        _model_metadata_cache = cache
        _model_metadata_cache_time = time.time()
        logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
        return cache
    except Exception as e:
        # Best effort: report through the module logger and serve stale data.
        logger.warning("Failed to fetch model metadata from OpenRouter: %s", e)
        return _model_metadata_cache or {}
def fetch_endpoint_model_metadata(
    base_url: str,
    api_key: str = "",
    force_refresh: bool = False,
) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from an OpenAI-compatible ``/models`` endpoint.

    This is used for explicit custom endpoints where hardcoded global model-name
    defaults are unreliable. Results are cached in memory per base URL.

    Args:
        base_url: Endpoint base URL, with or without a trailing ``/v1``.
        api_key: Optional bearer token sent in the ``Authorization`` header.
        force_refresh: When True, ignore a cached result for this URL.

    Returns:
        Mapping of model ID (and bare alias) to a dict with ``name`` plus, when
        the server reports them, ``context_length``, ``max_completion_tokens``
        and ``pricing``. Returns ``{}`` for empty/OpenRouter URLs and when
        every candidate URL fails (that failure is cached as well).
    """
    normalized = _normalize_base_url(base_url)
    if not normalized or _is_openrouter_base_url(normalized):
        return {}
    if not force_refresh:
        cached = _endpoint_model_metadata_cache.get(normalized)
        cached_at = _endpoint_model_metadata_cache_time.get(normalized, 0)
        if cached is not None and (time.time() - cached_at) < _ENDPOINT_MODEL_CACHE_TTL:
            return cached

    # Try the URL as configured first, then the variant with "/v1" toggled.
    candidates = [normalized]
    if normalized.endswith("/v1"):
        alternate = normalized[:-3].rstrip("/")
    else:
        alternate = normalized + "/v1"
    if alternate and alternate not in candidates:
        candidates.append(alternate)

    headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
    last_error: Optional[Exception] = None
    for candidate in candidates:
        url = candidate.rstrip("/") + "/models"
        try:
            response = requests.get(url, headers=headers, timeout=10)
            response.raise_for_status()
            payload = response.json()
            cache: Dict[str, Dict[str, Any]] = {}
            for model in payload.get("data", []):
                if not isinstance(model, dict):
                    continue
                model_id = model.get("id")
                if not model_id:
                    continue
                entry: Dict[str, Any] = {"name": model.get("name", model_id)}
                context_length = _extract_context_length(model)
                if context_length is not None:
                    entry["context_length"] = context_length
                max_completion_tokens = _extract_max_completion_tokens(model)
                if max_completion_tokens is not None:
                    entry["max_completion_tokens"] = max_completion_tokens
                pricing = _extract_pricing(model)
                if pricing:
                    entry["pricing"] = pricing
                _add_model_aliases(cache, model_id, entry)

            # If this is a llama.cpp server, query /props for actual allocated context
            is_llamacpp = any(
                m.get("owned_by") == "llamacpp"
                for m in payload.get("data", []) if isinstance(m, dict)
            )
            if is_llamacpp:
                try:
                    # Strip only a trailing "/v1". A blanket replace would also
                    # mangle hosts or paths that merely contain "/v1"
                    # (e.g. "http://v1.example/v1").
                    server_root = candidate.rstrip("/")
                    if server_root.endswith("/v1"):
                        server_root = server_root[:-3].rstrip("/")
                    props_url = server_root + "/props"
                    props_resp = requests.get(props_url, headers=headers, timeout=5)
                    if props_resp.ok:
                        props = props_resp.json()
                        gen_settings = props.get("default_generation_settings", {})
                        n_ctx = gen_settings.get("n_ctx")
                        model_alias = props.get("model_alias", "")
                        if n_ctx and model_alias and model_alias in cache:
                            cache[model_alias]["context_length"] = n_ctx
                except Exception as props_exc:
                    # /props is an optional refinement; keep the /models data.
                    logger.debug("Failed to query llama.cpp /props at %s: %s", candidate, props_exc)

            _endpoint_model_metadata_cache[normalized] = cache
            _endpoint_model_metadata_cache_time[normalized] = time.time()
            return cache
        except Exception as exc:
            last_error = exc

    if last_error:
        logger.debug("Failed to fetch model metadata from %s/models: %s", normalized, last_error)
    # Cache the failure too so an unreachable endpoint is not retried on every call.
    _endpoint_model_metadata_cache[normalized] = {}
    _endpoint_model_metadata_cache_time[normalized] = time.time()
    return {}
def _get_context_cache_path() -> Path:
    """Return the location of the persistent context length cache file.

    The file lives in ``$HERMES_HOME`` when that variable is set, otherwise in
    ``~/.hermes``.
    """
    default_home = Path.home() / ".hermes"
    configured_home = os.environ.get("HERMES_HOME", default_home)
    return Path(configured_home) / "context_length_cache.yaml"
def _load_context_cache() -> Dict[str, int]:
    """Load the model+provider -> context_length cache from disk.

    Returns:
        The mapping stored under ``context_lengths`` in the cache file. An
        empty dict is returned when the file is missing, unreadable, or does
        not contain a mapping there, so callers can always use ``.get()``.
    """
    path = _get_context_cache_path()
    if not path.exists():
        return {}
    try:
        with open(path, encoding="utf-8") as f:
            data = yaml.safe_load(f)
    except Exception as e:
        logger.debug("Failed to load context length cache: %s", e)
        return {}
    # An empty or hand-edited file can give None, a list, or a null
    # "context_lengths" value; none of those is a usable cache.
    if not isinstance(data, dict):
        return {}
    lengths = data.get("context_lengths")
    return lengths if isinstance(lengths, dict) else {}
def save_context_length(model: str, base_url: str, length: int) -> None:
    """Persist a discovered context length for a model+provider combo.

    Entries are keyed by ``model@base_url`` so the same model name served from
    different providers can have different limits. Nothing is written when the
    stored value already equals *length*; write failures are logged at debug
    level and otherwise ignored.
    """
    cache_key = f"{model}@{base_url}"
    known_lengths = _load_context_cache()
    if known_lengths.get(cache_key) == length:
        return  # already stored
    known_lengths[cache_key] = length
    target = _get_context_cache_path()
    try:
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, "w") as handle:
            yaml.dump({"context_lengths": known_lengths}, handle, default_flow_style=False)
        logger.info("Cached context length %s -> %s tokens", cache_key, f"{length:,}")
    except Exception as exc:
        logger.debug("Failed to save context length cache: %s", exc)
def get_cached_context_length(model: str, base_url: str) -> Optional[int]:
    """Return the persisted context length for ``model@base_url``, or None."""
    return _load_context_cache().get(f"{model}@{base_url}")
def get_next_probe_tier(current_length: int) -> Optional[int]:
    """Return the first probe tier below *current_length*, or None if there is none."""
    lower_tiers = (tier for tier in CONTEXT_PROBE_TIERS if tier < current_length)
    return next(lower_tiers, None)
# Patterns are tried in order; the first one whose captured number is a
# plausible context size wins. All of them run against lower-cased text.
_CONTEXT_LIMIT_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
        r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
        r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
        r'>\s*(\d{4,})\s*(?:max|limit|token)',  # "250000 tokens > 200000 maximum"
        r'(\d{4,})\s*(?:max(?:imum)?)\b',  # "200000 maximum"
        # Error-code style with underscores: "context_length_exceeded: 131072".
        # Kept last so messages already handled above parse exactly as before.
        r'context[\s_]*(?:length|size|window)(?:[\s_]*exceeded)?\s*(?:is|of|:|=)?\s*(\d{4,})',
    )
)


def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
    """Try to extract the actual context limit from an API error message.

    Many providers include the limit in their error text, e.g.:
    - "maximum context length is 32768 tokens"
    - "context_length_exceeded: 131072"
    - "Maximum context size 32768 exceeded"
    - "model's max context length is 65536"

    Args:
        error_msg: Error text from the provider. ``None``/empty is tolerated.

    Returns:
        The parsed limit when it lies between 1024 and 10,000,000 tokens,
        otherwise None.
    """
    if not error_msg:
        return None
    error_lower = str(error_msg).lower()
    for pattern in _CONTEXT_LIMIT_PATTERNS:
        match = pattern.search(error_lower)
        if not match:
            continue
        limit = int(match.group(1))
        # Sanity check: must be a reasonable context length
        if 1024 <= limit <= 10_000_000:
            return limit
    return None
def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
    """Return True if *candidate_id* (from server) matches *lookup_model* (configured).

    Two forms are accepted:
    - Exact match: "nvidia-nemotron-super-49b-v1" == "nvidia-nemotron-super-49b-v1"
    - Slug match: "nvidia/nvidia-nemotron-super-49b-v1" matches
      "nvidia-nemotron-super-49b-v1" (the part after the last "/" equals
      lookup_model)

    The slug form covers LM Studio's native API, which stores models as
    "publisher/slug" while users typically configure only the slug after the
    "local:" prefix.
    """
    if candidate_id == lookup_model:
        return True
    if "/" not in candidate_id:
        return False
    # Compare only the final path component of the server-side ID.
    slug = candidate_id.rpartition("/")[2]
    return slug == lookup_model
def _positive_int(value: Any) -> Optional[int]:
    """Return *value* as an int when it is a positive real number, else None."""
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return None
    try:
        number = int(value)
    except (ValueError, OverflowError):  # NaN / infinity
        return None
    return number if number > 0 else None


def _local_model_name_candidates(model: str) -> List[str]:
    """Return the model names to try against a local server, in order.

    The name as given comes first, because Ollama IDs legitimately contain ":"
    ("llama3:8b"). The variant with a leading "prefix:" removed (e.g.
    "local:model-name" -> "model-name") follows as a fallback.
    """
    names = [model]
    if ":" in model and not model.startswith("http"):
        stripped = model.split(":", 1)[1]
        if stripped and stripped not in names:
            names.append(stripped)
    return names


def _query_ollama_context_length(client: Any, server_url: str, names: List[str]) -> Optional[int]:
    """Read the context size from Ollama's /api/show (model_info, then num_ctx)."""
    for name in names:
        resp = client.post(f"{server_url}/api/show", json={"name": name})
        if resp.status_code != 200:
            continue
        data = resp.json()
        # Check model_info for context length
        model_info = data.get("model_info") or {}
        for key, value in model_info.items():
            if "context_length" in key:
                ctx = _positive_int(value)
                if ctx is not None:
                    return ctx
        # Check the parameters string for a "num_ctx <n>" line
        params = data.get("parameters") or ""
        for line in params.split("\n"):
            if "num_ctx" not in line:
                continue
            parts = line.strip().split()
            if len(parts) >= 2:
                try:
                    ctx = int(parts[-1])
                except ValueError:
                    continue
                if ctx > 0:
                    return ctx
    return None


def _query_lmstudio_context_length(client: Any, server_url: str, names: List[str]) -> Optional[int]:
    """Read the context size from LM Studio's native /api/v1/models listing.

    This is more reliable than the OpenAI-compatible /v1/models, which does not
    include context window information for LM Studio servers.
    """
    resp = client.get(f"{server_url}/api/v1/models")
    if resp.status_code != 200:
        return None
    models = resp.json().get("models") or []
    for name in names:
        for m in models:
            if not isinstance(m, dict):
                continue
            # LM Studio stores models as "publisher/slug" while users usually
            # configure only the slug, hence _model_id_matches.
            if not (
                _model_id_matches(m.get("key") or "", name)
                or _model_id_matches(m.get("id") or "", name)
            ):
                continue
            # Prefer loaded instance context (actual runtime value)
            for inst in m.get("loaded_instances") or []:
                cfg = (inst.get("config") or {}) if isinstance(inst, dict) else {}
                ctx = _positive_int(cfg.get("context_length"))
                if ctx is not None:
                    return ctx
            # Fall back to max_context_length (theoretical model max)
            ctx = _positive_int(m.get("max_context_length")) or _positive_int(m.get("context_length"))
            if ctx is not None:
                return ctx
    return None


def _first_context_value(record: Any) -> Optional[int]:
    """Return the first usable context size field of an OpenAI-style model record."""
    if not isinstance(record, dict):
        return None
    for field in ("max_model_len", "context_length", "max_tokens"):
        ctx = _positive_int(record.get(field))
        if ctx is not None:
            return ctx
    return None


def _query_openai_model_detail_context_length(client: Any, server_url: str, names: List[str]) -> Optional[int]:
    """Try /v1/models/{model} (vLLM reports max_model_len there)."""
    for name in names:
        resp = client.get(f"{server_url}/v1/models/{name}")
        if resp.status_code == 200:
            ctx = _first_context_value(resp.json())
            if ctx is not None:
                return ctx
    return None


def _query_openai_model_list_context_length(client: Any, server_url: str, names: List[str]) -> Optional[int]:
    """Try /v1/models and look the model up in the returned list."""
    resp = client.get(f"{server_url}/v1/models")
    if resp.status_code != 200:
        return None
    data = resp.json()
    models_list = (data.get("data") or []) if isinstance(data, dict) else []
    for name in names:
        for m in models_list:
            if isinstance(m, dict) and _model_id_matches(m.get("id") or "", name):
                ctx = _first_context_value(m)
                if ctx is not None:
                    return ctx
    return None


def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
    """Query a local server for the model's context length.

    The server type is detected first. The server-specific API (Ollama
    /api/show, LM Studio /api/v1/models) is then tried, followed by the
    OpenAI-compatible /v1/models/{model} and /v1/models endpoints.

    Args:
        model: Configured model name, optionally with a "prefix:" in front.
        base_url: Server base URL, with or without a trailing "/v1".

    Returns:
        The context length in tokens, or None when it cannot be determined.
        All errors are swallowed; this is a best-effort lookup.
    """
    import httpx

    names = _local_model_name_candidates(model)
    # Strip /v1 suffix to get the server root
    server_url = base_url.rstrip("/")
    if server_url.endswith("/v1"):
        server_url = server_url[:-3]
    try:
        server_type = detect_local_server_type(base_url)
    except Exception:
        server_type = None

    stages = []
    if server_type == "ollama":
        stages.append(_query_ollama_context_length)
    if server_type == "lm-studio":
        stages.append(_query_lmstudio_context_length)
    stages.append(_query_openai_model_detail_context_length)
    stages.append(_query_openai_model_list_context_length)

    try:
        with httpx.Client(timeout=3.0) as client:
            for stage in stages:
                try:
                    ctx = stage(client, server_url, names)
                except (ValueError, TypeError, AttributeError, KeyError) as exc:
                    # A malformed response from one endpoint should not stop
                    # the remaining fallbacks. Transport errors still abort
                    # (outer handler) so a dead server is not hit repeatedly.
                    logger.debug("Context length probe %s failed: %s", stage.__name__, exc)
                    continue
                if ctx is not None:
                    return ctx
    except Exception as exc:
        logger.debug("Local context length query failed for %s: %s", base_url, exc)
    return None
def get_model_context_length(
    model: str,
    base_url: str = "",
    api_key: str = "",
    config_context_length: int | None = None,
) -> int:
    """Get the context length for a model.

    Resolution order:
    0. Explicit config override (model.context_length in config.yaml)
    1. Persistent cache (previously discovered via probing)
    2. Active endpoint metadata (/models for explicit custom endpoints)
    3. Local server query (for local endpoints when model not in /models list)
    4. OpenRouter API metadata
    5. Hardcoded DEFAULT_CONTEXT_LENGTHS (fuzzy match for hosted routes only)
    6. First probe tier (2M) — will be narrowed on first context error

    Args:
        model: Model name, optionally carrying a "provider:" prefix.
        base_url: Endpoint base URL; empty means no explicit endpoint.
        api_key: Optional API key used when querying the endpoint's /models.
        config_context_length: Positive integer override from configuration.

    Returns:
        The context length in tokens.
    """
    # 0. Explicit config override — user knows best
    if isinstance(config_context_length, int) and config_context_length > 0:
        return config_context_length

    # Normalise provider-prefixed model names (e.g. "local:model-name" →
    # "model-name") so cache lookups and server queries use the bare ID that
    # local servers actually know about.
    # NOTE(review): this also cuts Ollama-style "name:tag" IDs ("llama3:8b"
    # becomes "8b"), which can collide in the cache and in the fuzzy match
    # below — confirm which prefixes are real provider prefixes.
    if ":" in model and not model.startswith("http"):
        model = model.split(":", 1)[1]

    # 1. Check persistent cache (model+provider)
    if base_url:
        cached = get_cached_context_length(model, base_url)
        if cached is not None:
            return cached

    # 2. Active endpoint metadata for explicit custom routes
    if _is_custom_endpoint(base_url):
        endpoint_metadata = fetch_endpoint_model_metadata(base_url, api_key=api_key)
        matched = endpoint_metadata.get(model)
        if not matched:
            # Single-model servers: if only one model is loaded, use it
            if len(endpoint_metadata) == 1:
                matched = next(iter(endpoint_metadata.values()))
            else:
                # Fuzzy match: substring in either direction
                for key, entry in endpoint_metadata.items():
                    if model in key or key in model:
                        matched = entry
                        break
        if matched:
            context_length = matched.get("context_length")
            if isinstance(context_length, int):
                return context_length
        if not _is_known_provider_base_url(base_url):
            # Explicit third-party endpoints should not borrow fuzzy global
            # defaults from unrelated providers with similarly named models.
            # 3. But first try querying the local server directly.
            if is_local_endpoint(base_url):
                local_ctx = _query_local_context_length(model, base_url)
                if local_ctx and local_ctx > 0:
                    save_context_length(model, base_url, local_ctx)
                    return local_ctx
            logger.info(
                "Could not detect context length for model %r at %s; "
                "defaulting to %s tokens (probe-down). Set model.context_length "
                "in config.yaml to override.",
                model, base_url, f"{CONTEXT_PROBE_TIERS[0]:,}",
            )
            return CONTEXT_PROBE_TIERS[0]

    # 4. OpenRouter API metadata
    metadata = fetch_model_metadata()
    if model in metadata:
        # Guard against a null context_length so an int is always returned.
        return metadata[model].get("context_length") or 128000

    # 5. Hardcoded defaults (fuzzy match — longest key first for specificity)
    for default_model, length in sorted(
        DEFAULT_CONTEXT_LENGTHS.items(), key=lambda x: len(x[0]), reverse=True
    ):
        if default_model in model or model in default_model:
            return length

    # Last-chance local server query before defaulting to 2M. Local endpoints
    # normally already returned in step 3, so this only covers URLs that were
    # not handled as unknown custom endpoints above.
    if base_url and is_local_endpoint(base_url):
        local_ctx = _query_local_context_length(model, base_url)
        if local_ctx and local_ctx > 0:
            save_context_length(model, base_url, local_ctx)
            return local_ctx

    # 6. Unknown model — start at highest probe tier
    return CONTEXT_PROBE_TIERS[0]
def estimate_tokens_rough(text: str) -> int:
    """Estimate the token count of *text* at roughly 4 characters per token.

    Intended for pre-flight checks only; empty or missing text counts as 0.
    """
    return len(text) // 4 if text else 0
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
    """Estimate the token count of a message list (pre-flight only).

    Each message is measured by the length of its ``str()`` form; the summed
    character count is divided by 4.
    """
    char_count = 0
    for message in messages:
        char_count += len(str(message))
    return char_count // 4