98 lines
3.3 KiB
Python
98 lines
3.3 KiB
Python
"""Model metadata, context lengths, and token estimation utilities.
|
|
|
|
Pure utility functions with no AIAgent dependency. Used by ContextCompressor
|
|
and run_agent.py for pre-flight context checks.
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
from typing import Any, Dict, List
|
|
|
|
import requests
|
|
|
|
from hermes_constants import OPENROUTER_MODELS_URL
|
|
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# How long a fetched metadata snapshot stays valid, in seconds (one hour).
_MODEL_CACHE_TTL = 60 * 60

# Most recent metadata snapshot, keyed by model id; empty until a fetch succeeds.
_model_metadata_cache: Dict[str, Dict[str, Any]] = {}

# time.time() stamp of the last successful fetch; 0 means "never fetched".
_model_metadata_cache_time: float = 0
|
|
|
|
# Offline fallback table of context-window sizes (in tokens), keyed by model id.
# Entry order is significant for any lookup that iterates this mapping, so keep
# new entries grouped with their vendor rather than re-sorting the table.
DEFAULT_CONTEXT_LENGTHS: Dict[str, int] = {
    # Anthropic
    "anthropic/claude-opus-4": 200_000,
    "anthropic/claude-opus-4.5": 200_000,
    "anthropic/claude-opus-4.6": 200_000,
    "anthropic/claude-sonnet-4": 200_000,
    "anthropic/claude-sonnet-4-20250514": 200_000,
    "anthropic/claude-haiku-4.5": 200_000,
    # OpenAI
    "openai/gpt-4o": 128_000,
    "openai/gpt-4-turbo": 128_000,
    "openai/gpt-4o-mini": 128_000,
    # Google (2**20 tokens)
    "google/gemini-2.0-flash": 1_048_576,
    "google/gemini-2.5-pro": 1_048_576,
    # Other vendors
    "meta-llama/llama-3.3-70b-instruct": 131_072,
    "deepseek/deepseek-chat-v3": 65_536,
    "qwen/qwen-2.5-72b-instruct": 32_768,
}
|
|
|
|
|
|
def fetch_model_metadata(force_refresh: bool = False) -> Dict[str, Dict[str, Any]]:
    """Fetch model metadata from OpenRouter (cached for 1 hour).

    Args:
        force_refresh: When True, skip the in-memory cache and always issue
            a new request.

    Returns:
        A mapping from model id (and from the model's ``canonical_slug`` when
        it differs from the id) to a dict with the keys ``context_length``,
        ``max_completion_tokens``, ``name`` and ``pricing``. If the request
        or parsing fails, the failure is logged and the previously cached
        mapping (possibly empty) is returned; no exception is raised.
    """
    global _model_metadata_cache, _model_metadata_cache_time

    cache_is_fresh = (time.time() - _model_metadata_cache_time) < _MODEL_CACHE_TTL
    if not force_refresh and _model_metadata_cache and cache_is_fresh:
        return _model_metadata_cache

    try:
        response = requests.get(OPENROUTER_MODELS_URL, timeout=10)
        response.raise_for_status()
        data = response.json()

        cache: Dict[str, Dict[str, Any]] = {}
        for model in data.get("data", []):
            model_id = model.get("id", "")
            # Guard against an explicit null: .get("top_provider", {}) would
            # return None there, and the resulting AttributeError would abort
            # the entire refresh because of a single entry.
            top_provider = model.get("top_provider") or {}
            cache[model_id] = {
                "context_length": model.get("context_length", 128000),
                "max_completion_tokens": top_provider.get("max_completion_tokens", 4096),
                "name": model.get("name", model_id),
                "pricing": model.get("pricing", {}),
            }
            # Alias the canonical slug to the same entry so either spelling
            # resolves to one shared dict.
            canonical = model.get("canonical_slug", "")
            if canonical and canonical != model_id:
                cache[canonical] = cache[model_id]

        _model_metadata_cache = cache
        _model_metadata_cache_time = time.time()
        logger.debug("Fetched metadata for %s models from OpenRouter", len(cache))
        return cache

    except Exception as e:
        # Deliberate best-effort: callers fall back to stale or default data.
        # Log through the module logger (not the root-level logging.warning)
        # so per-module configuration applies and the root logger is not
        # implicitly configured as a side effect.
        logger.warning("Failed to fetch model metadata from OpenRouter: %s", e)
        return _model_metadata_cache or {}
|
|
|
|
|
|
def get_model_context_length(model: str) -> int:
    """Get the context length for a model (API first, then fallback defaults).

    Resolution order:
      1. ``context_length`` from the fetched OpenRouter metadata, when the
         model is listed there with a usable value.
      2. ``DEFAULT_CONTEXT_LENGTHS``: the longest table key contained in
         ``model`` (this covers exact matches and versioned/suffixed ids).
      3. ``DEFAULT_CONTEXT_LENGTHS``: the shortest table key that contains
         ``model`` (this covers abbreviated ids).
      4. A generic fallback of 128000 tokens.

    Args:
        model: Model identifier, e.g. ``"openai/gpt-4o"``.

    Returns:
        The context window size in tokens.
    """
    fallback = 128000

    # An empty string is a substring of every key, so it would otherwise be
    # reported with the first table entry's length; None would raise in the
    # substring checks below. Neither names a model, so use the fallback.
    if not model:
        return fallback

    metadata = fetch_model_metadata()
    api_length = metadata.get(model, {}).get("context_length")
    if api_length:
        return api_length

    # Prefer the most specific default so the result does not depend on the
    # table's insertion order. Ties resolve to the earliest entry.
    contained = [key for key in DEFAULT_CONTEXT_LENGTHS if key in model]
    if contained:
        return DEFAULT_CONTEXT_LENGTHS[max(contained, key=len)]

    # `model` may be an abbreviation of a known id; take the closest
    # (shortest) key that contains it.
    containing = [key for key in DEFAULT_CONTEXT_LENGTHS if model in key]
    if containing:
        return DEFAULT_CONTEXT_LENGTHS[min(containing, key=len)]

    return fallback
|
|
|
|
|
|
def estimate_tokens_rough(text: str) -> int:
    """Approximate the token count of ``text`` at about four characters per token.

    Meant for cheap pre-flight checks only. Falsy input (such as an empty
    string) counts as zero tokens.
    """
    return len(text) // 4 if text else 0
|
|
|
|
|
|
def estimate_messages_tokens_rough(messages: List[Dict[str, Any]]) -> int:
    """Approximate the token count of a message list (pre-flight use only).

    Each message is measured by the length of its ``str()`` rendering. The
    lengths are summed first and the total is then divided by four, so the
    rounding happens once rather than per message.
    """
    char_count = 0
    for message in messages:
        char_count += len(str(message))
    return char_count // 4
|