195 lines
5.6 KiB
Python
195 lines
5.6 KiB
Python
"""Helpers for optional cheap-vs-strong model routing."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import re
|
|
from typing import Any, Dict, Optional
|
|
|
|
from utils import is_truthy_value
|
|
|
|
_COMPLEX_KEYWORDS = {
|
|
"debug",
|
|
"debugging",
|
|
"implement",
|
|
"implementation",
|
|
"refactor",
|
|
"patch",
|
|
"traceback",
|
|
"stacktrace",
|
|
"exception",
|
|
"error",
|
|
"analyze",
|
|
"analysis",
|
|
"investigate",
|
|
"architecture",
|
|
"design",
|
|
"compare",
|
|
"benchmark",
|
|
"optimize",
|
|
"optimise",
|
|
"review",
|
|
"terminal",
|
|
"shell",
|
|
"tool",
|
|
"tools",
|
|
"pytest",
|
|
"test",
|
|
"tests",
|
|
"plan",
|
|
"planning",
|
|
"delegate",
|
|
"subagent",
|
|
"cron",
|
|
"docker",
|
|
"kubernetes",
|
|
}
|
|
|
|
_URL_RE = re.compile(r"https?://|www\.", re.IGNORECASE)
|
|
|
|
|
|
def _coerce_bool(value: Any, default: bool = False) -> bool:
    """Interpret *value* as a boolean by delegating to ``is_truthy_value``.

    ``default`` is forwarded to the helper unchanged as its ``default``
    keyword argument.
    """
    coerced = is_truthy_value(value, default=default)
    return coerced
|
|
|
|
|
|
def _coerce_int(value: Any, default: int) -> int:
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError):
|
|
return default
|
|
|
|
|
|
def choose_cheap_model_route(user_message: str, routing_config: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Pick the cheap-model route for a short, plain-looking message.

    Returns ``None`` (meaning "stay on the primary model") unless routing is
    enabled, a cheap model with both provider and model is configured, and
    the message passes every simplicity check below. On success the result
    is a copy of the ``cheap_model`` config with normalised ``provider`` /
    ``model`` values and ``routing_reason`` set to ``"simple_turn"``.
    """
    settings = routing_config or {}
    if not _coerce_bool(settings.get("enabled"), False):
        return None

    cheap_spec = settings.get("cheap_model") or {}
    if not isinstance(cheap_spec, dict):
        return None

    provider_name = str(cheap_spec.get("provider") or "").strip().lower()
    model_name = str(cheap_spec.get("model") or "").strip()
    if not (provider_name and model_name):
        return None

    message = (user_message or "").strip()
    if not message:
        return None

    char_limit = _coerce_int(settings.get("max_simple_chars"), 160)
    word_limit = _coerce_int(settings.get("max_simple_words"), 28)

    # Any one of these signals is enough to keep the primary model: too long,
    # too many words, more than one line break, inline/fenced code, or a URL.
    # A single-backtick test also covers triple-backtick fences.
    looks_heavy = (
        len(message) > char_limit
        or len(message.split()) > word_limit
        or message.count("\n") > 1
        or "`" in message
        or _URL_RE.search(message) is not None
    )
    if looks_heavy:
        return None

    # Compare whole whitespace-separated tokens, with surrounding punctuation
    # removed, against the complexity keyword set.
    tokens = {piece.strip(".,:;!?()[]{}\"'`") for piece in message.lower().split()}
    if not tokens.isdisjoint(_COMPLEX_KEYWORDS):
        return None

    return {
        **cheap_spec,
        "provider": provider_name,
        "model": model_name,
        "routing_reason": "simple_turn",
    }
|
|
|
|
|
|
def _primary_turn_route(primary: Dict[str, Any]) -> Dict[str, Any]:
    """Build the turn route that keeps the primary model and its runtime."""
    return {
        "model": primary.get("model"),
        "runtime": {
            "api_key": primary.get("api_key"),
            "base_url": primary.get("base_url"),
            "provider": primary.get("provider"),
            "api_mode": primary.get("api_mode"),
            "command": primary.get("command"),
            "args": list(primary.get("args") or []),
            "credential_pool": primary.get("credential_pool"),
        },
        "label": None,
        "signature": (
            primary.get("model"),
            primary.get("provider"),
            primary.get("base_url"),
            primary.get("api_mode"),
            primary.get("command"),
            tuple(primary.get("args") or ()),
        ),
    }


def resolve_turn_route(user_message: str, routing_config: Optional[Dict[str, Any]], primary: Dict[str, Any]) -> Dict[str, Any]:
    """Resolve the effective model/runtime for one turn.

    Args:
        user_message: The text of the current user turn.
        routing_config: Optional routing settings, passed straight to
            ``choose_cheap_model_route``.
        primary: Settings of the primary model; read with ``.get`` for
            ``model``, ``api_key``, ``base_url``, ``provider``, ``api_mode``,
            ``command``, ``args`` and ``credential_pool``.

    Returns:
        A dict with ``model``, ``runtime``, ``label`` and ``signature`` keys.
        The primary route (``label`` is ``None``) is returned when no cheap
        route is chosen or when resolving the cheap provider raises;
        otherwise the cheap route with a human-readable ``label``.
    """
    route = choose_cheap_model_route(user_message, routing_config)
    if not route:
        return _primary_turn_route(primary)

    # Kept function-local as in the original; an ImportError here propagates.
    # NOTE(review): presumably deferred to avoid an import cycle - confirm.
    from hermes_cli.runtime_provider import resolve_runtime_provider

    # An API key is only passed explicitly when the route names an
    # environment variable and that variable is set to a non-empty value.
    explicit_api_key = None
    api_key_env = str(route.get("api_key_env") or "").strip()
    if api_key_env:
        explicit_api_key = os.getenv(api_key_env) or None

    try:
        runtime = resolve_runtime_provider(
            requested=route.get("provider"),
            explicit_api_key=explicit_api_key,
            explicit_base_url=route.get("base_url"),
        )
    except Exception:
        # Best-effort by design: a broken cheap route must never break the
        # turn. Record why at DEBUG so misconfiguration is diagnosable,
        # then fall back to the primary model.
        import logging

        logging.getLogger(__name__).debug(
            "Cheap-model route for provider %r could not be resolved; using primary model.",
            route.get("provider"),
            exc_info=True,
        )
        return _primary_turn_route(primary)

    # NOTE(review): unlike the primary route, this runtime dict carries no
    # "credential_pool" key - confirm that is intended.
    return {
        "model": route.get("model"),
        "runtime": {
            "api_key": runtime.get("api_key"),
            "base_url": runtime.get("base_url"),
            "provider": runtime.get("provider"),
            "api_mode": runtime.get("api_mode"),
            "command": runtime.get("command"),
            "args": list(runtime.get("args") or []),
        },
        "label": f"smart route → {route.get('model')} ({runtime.get('provider')})",
        "signature": (
            route.get("model"),
            runtime.get("provider"),
            runtime.get("base_url"),
            runtime.get("api_mode"),
            runtime.get("command"),
            tuple(runtime.get("args") or ()),
        ),
    }
|