diff --git a/cli.py b/cli.py index efdaeee5..ce9e30b8 100644 --- a/cli.py +++ b/cli.py @@ -3562,151 +3562,83 @@ class HermesCLI: # Use original case so model names like "Anthropic/Claude-Opus-4" are preserved parts = cmd_original.split(maxsplit=1) if len(parts) > 1: - from hermes_cli.auth import resolve_provider - from hermes_cli.models import ( - parse_model_input, - validate_requested_model, - _PROVIDER_LABELS, - ) + from hermes_cli.model_switch import switch_model, switch_to_custom_provider raw_input = parts[1].strip() # Handle bare "/model custom" — switch to custom provider # and auto-detect the model from the endpoint. if raw_input.strip().lower() == "custom": - from hermes_cli.runtime_provider import ( - resolve_runtime_provider, - _auto_detect_local_model, - ) - try: - runtime = resolve_runtime_provider(requested="custom") - cust_base = runtime.get("base_url", "") - cust_key = runtime.get("api_key", "") - if not cust_base or "openrouter.ai" in cust_base: - print("(>_<) No custom endpoint configured.") - print(" Set model.base_url in config.yaml, or set OPENAI_BASE_URL in .env,") - print(" or run: hermes setup → Custom OpenAI-compatible endpoint") - return True - detected_model = _auto_detect_local_model(cust_base) - if detected_model: - self.model = detected_model - self.requested_provider = "custom" - self.provider = "custom" - self.api_key = cust_key - self.base_url = cust_base - self.agent = None - save_config_value("model.default", detected_model) - save_config_value("model.provider", "custom") - save_config_value("model.base_url", cust_base) - print(f"(^_^)b Model changed to: {detected_model} [provider: Custom]") - print(f" Endpoint: {cust_base}") - print(f" Status: connected (model auto-detected)") - else: - print(f"(>_<) Custom endpoint at {cust_base} is reachable but no single model was auto-detected.") - print(f" Specify the model explicitly: /model custom:") - except Exception as e: - print(f"(>_<) Could not resolve custom endpoint: {e}") + result 
= switch_to_custom_provider() + if result.success: + self.model = result.model + self.requested_provider = "custom" + self.provider = "custom" + self.api_key = result.api_key + self.base_url = result.base_url + self.agent = None + save_config_value("model.default", result.model) + save_config_value("model.provider", "custom") + save_config_value("model.base_url", result.base_url) + print(f"(^_^)b Model changed to: {result.model} [provider: Custom]") + print(f" Endpoint: {result.base_url}") + print(f" Status: connected (model auto-detected)") + else: + print(f"(>_<) {result.error_message}") return True - # Parse provider:model syntax (e.g. "openrouter:anthropic/claude-sonnet-4.5") + # Core model-switching pipeline (shared with gateway) current_provider = self.provider or self.requested_provider or "openrouter" - target_provider, new_model = parse_model_input(raw_input, current_provider) - # Auto-detect provider when no explicit provider:model syntax was used. - # Skip auto-detection for custom providers — the model name might - # coincidentally match a known provider's catalog, but the user - # intends to use it on their custom endpoint. Require explicit - # provider:model syntax (e.g. /model openai-codex:gpt-5.2-codex) - # to switch away from a custom endpoint. 
- _base = self.base_url or "" - is_custom = current_provider == "custom" or ( - "localhost" in _base or "127.0.0.1" in _base + result = switch_model( + raw_input, + current_provider, + current_base_url=self.base_url or "", + current_api_key=self.api_key or "", ) - if target_provider == current_provider and not is_custom: - from hermes_cli.models import detect_provider_for_model - detected = detect_provider_for_model(new_model, current_provider) - if detected: - target_provider, new_model = detected - provider_changed = target_provider != current_provider - # If provider is changing, re-resolve credentials for the new provider - api_key_for_probe = self.api_key - base_url_for_probe = self.base_url - if provider_changed: - try: - from hermes_cli.runtime_provider import resolve_runtime_provider - runtime = resolve_runtime_provider(requested=target_provider) - api_key_for_probe = runtime.get("api_key", "") - base_url_for_probe = runtime.get("base_url", "") - except Exception as e: - provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) - if target_provider == "custom": - print(f"(>_<) Custom endpoint not configured. 
Set OPENAI_BASE_URL and OPENAI_API_KEY,") - print(f" or run: hermes setup → Custom OpenAI-compatible endpoint") - else: - print(f"(>_<) Could not resolve credentials for provider '{provider_label}': {e}") - print(f"(^_^) Current model unchanged: {self.model}") - return True - - try: - validation = validate_requested_model( - new_model, - target_provider, - api_key=api_key_for_probe, - base_url=base_url_for_probe, - ) - except Exception: - validation = {"accepted": True, "persist": True, "recognized": False, "message": None} - - if not validation.get("accepted"): - print(f"(>_<) {validation.get('message')}") - print(f" Model unchanged: {self.model}") - if "Did you mean" not in (validation.get("message") or ""): - print(" Tip: Use /model to see available models, /provider to see providers") + if not result.success: + print(f"(>_<) {result.error_message}") + if "Did you mean" not in result.error_message: + print(f" Model unchanged: {self.model}") + if "credentials" not in result.error_message.lower(): + print(" Tip: Use /model to see available models, /provider to see providers") else: - self.model = new_model + self.model = result.new_model self.agent = None # Force re-init - if provider_changed: - self.requested_provider = target_provider - self.provider = target_provider - self.api_key = api_key_for_probe - self.base_url = base_url_for_probe + if result.provider_changed: + self.requested_provider = result.target_provider + self.provider = result.target_provider + self.api_key = result.api_key + self.base_url = result.base_url - provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) - provider_note = f" [provider: {provider_label}]" if provider_changed else "" + provider_note = f" [provider: {result.provider_label}]" if result.provider_changed else "" - if validation.get("persist"): - saved_model = save_config_value("model.default", new_model) - if provider_changed: - save_config_value("model.provider", target_provider) - # Persist base_url for 
custom endpoints so it - # survives restart; clear it when switching away - # from custom to prevent stale URLs leaking into - # the new provider's resolution (#2562 Phase 2). - if base_url_for_probe and "openrouter.ai" not in (base_url_for_probe or ""): - save_config_value("model.base_url", base_url_for_probe) + if result.persist: + saved_model = save_config_value("model.default", result.new_model) + if result.provider_changed: + save_config_value("model.provider", result.target_provider) + # Persist base_url for custom endpoints; clear + # when switching away from custom (#2562 Phase 2). + if result.base_url and "openrouter.ai" not in (result.base_url or ""): + save_config_value("model.base_url", result.base_url) else: save_config_value("model.base_url", None) if saved_model: - print(f"(^_^)b Model changed to: {new_model}{provider_note} (saved to config)") + print(f"(^_^)b Model changed to: {result.new_model}{provider_note} (saved to config)") else: - print(f"(^_^) Model changed to: {new_model}{provider_note} (this session only)") + print(f"(^_^) Model changed to: {result.new_model}{provider_note} (this session only)") else: - message = validation.get("message") or "" - print(f"(^_^) Model changed to: {new_model}{provider_note} (this session only)") - if message: - print(f" Reason: {message}") + print(f"(^_^) Model changed to: {result.new_model}{provider_note} (this session only)") + if result.warning_message: + print(f" Reason: {result.warning_message}") print(" Note: Model will revert on restart. 
Use a verified model to save to config.") # Show endpoint info for custom providers - _target_is_custom = target_provider == "custom" or ( - base_url_for_probe and "openrouter.ai" not in (base_url_for_probe or "") - and ("localhost" in (base_url_for_probe or "") or "127.0.0.1" in (base_url_for_probe or "")) - ) - if _target_is_custom or (is_custom and not provider_changed): - endpoint = base_url_for_probe or self.base_url or "custom endpoint" + if result.is_custom_target: + endpoint = result.base_url or self.base_url or "custom endpoint" print(f" Endpoint: {endpoint}") - if not provider_changed: + if not result.provider_changed: print(f" Tip: To switch providers, use /model provider:model") print(f" e.g. /model openai-codex:gpt-5.2-codex") else: diff --git a/gateway/run.py b/gateway/run.py index 91276c2a..c8cfae5d 100644 --- a/gateway/run.py +++ b/gateway/run.py @@ -2854,117 +2854,10 @@ class GatewayRunner: # Handle bare "/model custom" — switch to custom provider # and auto-detect the model from the endpoint. 
if args.strip().lower() == "custom": - from hermes_cli.runtime_provider import ( - resolve_runtime_provider as _rtp_custom, - _auto_detect_local_model, - ) - try: - runtime = _rtp_custom(requested="custom") - cust_base = runtime.get("base_url", "") - if not cust_base or "openrouter.ai" in cust_base: - return ( - "⚠️ No custom endpoint configured.\n" - "Set `model.base_url` in config.yaml, or `OPENAI_BASE_URL` in .env,\n" - "or run: `hermes setup` → Custom OpenAI-compatible endpoint" - ) - detected_model = _auto_detect_local_model(cust_base) - if detected_model: - try: - user_config = {} - if config_path.exists(): - with open(config_path, encoding="utf-8") as f: - user_config = yaml.safe_load(f) or {} - if "model" not in user_config or not isinstance(user_config["model"], dict): - user_config["model"] = {} - user_config["model"]["default"] = detected_model - user_config["model"]["provider"] = "custom" - user_config["model"]["base_url"] = cust_base - with open(config_path, 'w', encoding="utf-8") as f: - yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) - except Exception as e: - return f"⚠️ Failed to save model change: {e}" - os.environ["HERMES_MODEL"] = detected_model - os.environ["HERMES_INFERENCE_PROVIDER"] = "custom" - self._effective_model = None - self._effective_provider = None - return ( - f"🤖 Model changed to `{detected_model}` (saved to config)\n" - f"**Provider:** Custom\n" - f"**Endpoint:** `{cust_base}`\n" - f"_Model auto-detected from endpoint. 
Takes effect on next message._" - ) - else: - return ( - f"⚠️ Custom endpoint at `{cust_base}` is reachable but no single model was auto-detected.\n" - f"Specify the model explicitly: `/model custom:`" - ) - except Exception as e: - return f"⚠️ Could not resolve custom endpoint: {e}" - - # Parse provider:model syntax - target_provider, new_model = parse_model_input(args, current_provider) - - # Detect custom/local provider — skip auto-detection to prevent - # silently accepting an OpenRouter model name on a localhost endpoint. - # Users must use explicit provider:model syntax to switch away. - _resolved_base = "" - try: - from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp - _resolved_base = _rtp(requested=current_provider).get("base_url", "") - except Exception: - pass - is_custom = current_provider == "custom" or ( - "localhost" in _resolved_base or "127.0.0.1" in _resolved_base - ) - - # Auto-detect provider when no explicit provider:model syntax was used - if target_provider == current_provider and not is_custom: - from hermes_cli.models import detect_provider_for_model - detected = detect_provider_for_model(new_model, current_provider) - if detected: - target_provider, new_model = detected - provider_changed = target_provider != current_provider - - # Resolve credentials for the target provider (for API probe) - api_key = os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or "" - base_url = "https://openrouter.ai/api/v1" - if provider_changed: - try: - from hermes_cli.runtime_provider import resolve_runtime_provider - runtime = resolve_runtime_provider(requested=target_provider) - api_key = runtime.get("api_key", "") - base_url = runtime.get("base_url", "") - except Exception as e: - provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) - return f"⚠️ Could not resolve credentials for provider '{provider_label}': {e}" - else: - # Use current provider's base_url from config or registry - try: - from 
hermes_cli.runtime_provider import resolve_runtime_provider - runtime = resolve_runtime_provider(requested=current_provider) - api_key = runtime.get("api_key", "") - base_url = runtime.get("base_url", "") - except Exception: - pass - - # Validate the model against the live API - try: - validation = validate_requested_model( - new_model, - target_provider, - api_key=api_key, - base_url=base_url, - ) - except Exception: - validation = {"accepted": True, "persist": True, "recognized": False, "message": None} - - if not validation.get("accepted"): - msg = validation.get("message", "Invalid model") - tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else "" - return f"⚠️ {msg}{tip}" - - # Persist to config only if validation approves - if validation.get("persist"): + from hermes_cli.model_switch import switch_to_custom_provider + cust_result = switch_to_custom_provider() + if not cust_result.success: + return f"⚠️ {cust_result.error_message}" try: user_config = {} if config_path.exists(): @@ -2972,14 +2865,63 @@ class GatewayRunner: user_config = yaml.safe_load(f) or {} if "model" not in user_config or not isinstance(user_config["model"], dict): user_config["model"] = {} - user_config["model"]["default"] = new_model - if provider_changed: - user_config["model"]["provider"] = target_provider - # Persist base_url for custom endpoints so it survives - # restart; clear it when switching away from custom to - # prevent stale URLs leaking (#2562 Phase 2). 
- if base_url and "openrouter.ai" not in (base_url or ""): - user_config["model"]["base_url"] = base_url + user_config["model"]["default"] = cust_result.model + user_config["model"]["provider"] = "custom" + user_config["model"]["base_url"] = cust_result.base_url + with open(config_path, 'w', encoding="utf-8") as f: + yaml.dump(user_config, f, default_flow_style=False, sort_keys=False) + except Exception as e: + return f"⚠️ Failed to save model change: {e}" + os.environ["HERMES_MODEL"] = cust_result.model + os.environ["HERMES_INFERENCE_PROVIDER"] = "custom" + self._effective_model = None + self._effective_provider = None + return ( + f"🤖 Model changed to `{cust_result.model}` (saved to config)\n" + f"**Provider:** Custom\n" + f"**Endpoint:** `{cust_result.base_url}`\n" + f"_Model auto-detected from endpoint. Takes effect on next message._" + ) + + # Core model-switching pipeline (shared with CLI) + from hermes_cli.model_switch import switch_model + + # Resolve current base_url for is_custom detection + _resolved_base = "" + try: + from hermes_cli.runtime_provider import resolve_runtime_provider as _rtp + _resolved_base = _rtp(requested=current_provider).get("base_url", "") + except Exception: + pass + + result = switch_model( + args, + current_provider, + current_base_url=_resolved_base, + current_api_key=os.getenv("OPENROUTER_API_KEY") or os.getenv("OPENAI_API_KEY") or "", + ) + + if not result.success: + msg = result.error_message + tip = "\n\nUse `/model` to see available models, `/provider` to see providers" if "Did you mean" not in msg else "" + return f"⚠️ {msg}{tip}" + + # Persist to config only if validation approves + if result.persist: + try: + user_config = {} + if config_path.exists(): + with open(config_path, encoding="utf-8") as f: + user_config = yaml.safe_load(f) or {} + if "model" not in user_config or not isinstance(user_config["model"], dict): + user_config["model"] = {} + user_config["model"]["default"] = result.new_model + if 
result.provider_changed: + user_config["model"]["provider"] = result.target_provider + # Persist base_url for custom endpoints; clear when + # switching away from custom (#2562 Phase 2). + if result.base_url and "openrouter.ai" not in (result.base_url or ""): + user_config["model"]["base_url"] = result.base_url else: user_config["model"].pop("base_url", None) with open(config_path, 'w', encoding="utf-8") as f: @@ -2988,41 +2930,34 @@ class GatewayRunner: return f"⚠️ Failed to save model change: {e}" # Set env vars so the next agent run picks up the change - os.environ["HERMES_MODEL"] = new_model - if provider_changed: - os.environ["HERMES_INFERENCE_PROVIDER"] = target_provider + os.environ["HERMES_MODEL"] = result.new_model + if result.provider_changed: + os.environ["HERMES_INFERENCE_PROVIDER"] = result.target_provider - provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) - provider_note = f"\n**Provider:** {provider_label}" if provider_changed else "" + provider_note = f"\n**Provider:** {result.provider_label}" if result.provider_changed else "" warning = "" - if validation.get("message"): - warning = f"\n⚠️ {validation['message']}" + if result.warning_message: + warning = f"\n⚠️ {result.warning_message}" + + persist_note = "saved to config" if result.persist else "this session only — will revert on restart" - if validation.get("persist"): - persist_note = "saved to config" - else: - persist_note = "this session only — will revert on restart" # Clear fallback state since user explicitly chose a model self._effective_model = None self._effective_provider = None # Show endpoint info for custom providers - _target_is_custom = target_provider == "custom" or ( - base_url and "openrouter.ai" not in (base_url or "") - and ("localhost" in (base_url or "") or "127.0.0.1" in (base_url or "")) - ) custom_hint = "" - if _target_is_custom or (is_custom and not provider_changed): - endpoint = base_url or _resolved_base or "custom endpoint" + if 
result.is_custom_target: + endpoint = result.base_url or _resolved_base or "custom endpoint" custom_hint = f"\n**Endpoint:** `{endpoint}`" - if not provider_changed: + if not result.provider_changed: custom_hint += ( "\n_To switch providers, use_ `/model provider:model`" "\n_e.g._ `/model openrouter:anthropic/claude-sonnet-4`" ) - return f"🤖 Model changed to `{new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_" + return f"🤖 Model changed to `{result.new_model}` ({persist_note}){provider_note}{warning}{custom_hint}\n_(takes effect on next message)_" async def _handle_provider_command(self, event: MessageEvent) -> str: """Handle /provider command - show available providers.""" diff --git a/hermes_cli/model_switch.py b/hermes_cli/model_switch.py new file mode 100644 index 00000000..57ca5380 --- /dev/null +++ b/hermes_cli/model_switch.py @@ -0,0 +1,234 @@ +"""Shared model-switching logic for CLI and gateway /model commands. + +Both the CLI (cli.py) and gateway (gateway/run.py) /model handlers +share the same core pipeline: + + parse_model_input → is_custom detection → auto-detect provider + → credential resolution → validate model → return result + +This module extracts that shared pipeline into pure functions that +return result objects. The callers handle all platform-specific +concerns: state mutation, config persistence, output formatting. 
from __future__ import annotations

import os
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelSwitchResult:
    """Result of a model switch attempt produced by ``switch_model``.

    On failure only ``success``, ``error_message`` and (where known)
    ``new_model`` / ``target_provider`` are populated; every other field
    keeps its default.
    """

    # True when the requested model was accepted by validation.
    success: bool
    # Model name after parsing / provider auto-detection.
    new_model: str = ""
    # Provider the model should run on after the switch.
    target_provider: str = ""
    # True when target_provider differs from the provider passed in.
    provider_changed: bool = False
    # Credentials / endpoint resolved for the target provider.
    api_key: str = ""
    base_url: str = ""
    # True when validation says the choice may be written to config.
    persist: bool = False
    # Human-readable reason for a failed switch (always a str).
    error_message: str = ""
    # Non-fatal validation message accompanying a successful switch.
    warning_message: str = ""
    # True when the target is the "custom" provider or a local endpoint.
    is_custom_target: bool = False
    # Display label for target_provider (falls back to the provider id).
    provider_label: str = ""


@dataclass
class CustomAutoResult:
    """Result of switching to bare 'custom' provider with auto-detect."""

    # True when an endpoint was resolved and a model was auto-detected.
    success: bool
    # Auto-detected model name (empty on failure).
    model: str = ""
    # Endpoint and key resolved for the custom provider.
    base_url: str = ""
    api_key: str = ""
    # Human-readable reason for a failed switch.
    error_message: str = ""


def _is_local_endpoint(base_url: Optional[str]) -> bool:
    """Return True when *base_url* mentions localhost or 127.0.0.1."""
    url = base_url or ""
    return "localhost" in url or "127.0.0.1" in url


def switch_model(
    raw_input: str,
    current_provider: str,
    current_base_url: str = "",
    current_api_key: str = "",
) -> ModelSwitchResult:
    """Core model-switching pipeline shared between CLI and gateway.

    Handles parsing, provider detection, credential resolution, and
    model validation. Does NOT handle config persistence, state
    mutation, or output formatting — those are caller responsibilities.

    Args:
        raw_input: The user's model input (e.g. "claude-sonnet-4",
            "zai:glm-5", "custom:local:qwen").
        current_provider: The currently active provider.
        current_base_url: The currently active base URL (used for
            is_custom detection).
        current_api_key: The currently active API key.

    Returns:
        ModelSwitchResult with all information the caller needs to
        apply the switch and format output. ``error_message`` is always
        a string, so callers can safely run substring checks on it.
    """
    from hermes_cli.models import (
        parse_model_input,
        detect_provider_for_model,
        validate_requested_model,
        _PROVIDER_LABELS,
    )
    from hermes_cli.runtime_provider import resolve_runtime_provider

    # Step 1: Parse provider:model syntax
    target_provider, new_model = parse_model_input(raw_input, current_provider)

    # Step 2: Detect if we're currently on a custom endpoint
    is_custom = current_provider == "custom" or _is_local_endpoint(current_base_url)

    # Step 3: Auto-detect provider when no explicit provider:model syntax
    # was used. Skip for custom providers — the model name might
    # coincidentally match a known provider's catalog.
    if target_provider == current_provider and not is_custom:
        detected = detect_provider_for_model(new_model, current_provider)
        if detected:
            target_provider, new_model = detected

    provider_changed = target_provider != current_provider

    # Step 4: Resolve credentials for target provider
    api_key = current_api_key
    base_url = current_base_url
    if provider_changed:
        try:
            runtime = resolve_runtime_provider(requested=target_provider)
        except Exception as e:
            if target_provider == "custom":
                return ModelSwitchResult(
                    success=False,
                    target_provider=target_provider,
                    error_message=(
                        "No custom endpoint configured. Set model.base_url "
                        "in config.yaml, or set OPENAI_BASE_URL in .env, "
                        "or run: hermes setup → Custom OpenAI-compatible endpoint"
                    ),
                )
            provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
            return ModelSwitchResult(
                success=False,
                target_provider=target_provider,
                error_message=(
                    f"Could not resolve credentials for provider "
                    f"'{provider_label}': {e}"
                ),
            )
        api_key = runtime.get("api_key", "")
        base_url = runtime.get("base_url", "")
    else:
        # Also resolve for an unchanged provider to get an accurate
        # base_url for validation probing. Best-effort: on failure the
        # caller-supplied values are kept.
        try:
            runtime = resolve_runtime_provider(requested=current_provider)
            api_key = runtime.get("api_key", "")
            base_url = runtime.get("base_url", "")
        except Exception:
            pass

    # Step 5: Validate the model. If the validator itself blows up,
    # accept the model as unrecognized rather than blocking the switch.
    try:
        validation = validate_requested_model(
            new_model,
            target_provider,
            api_key=api_key,
            base_url=base_url,
        )
    except Exception:
        validation = {
            "accepted": True,
            "persist": True,
            "recognized": False,
            "message": None,
        }

    if not validation.get("accepted"):
        # `or` (not a .get default) so an explicit ``"message": None``
        # still yields a string; callers do ``"..." in error_message``.
        msg = validation.get("message") or "Invalid model"
        return ModelSwitchResult(
            success=False,
            new_model=new_model,
            target_provider=target_provider,
            error_message=msg,
        )

    # Step 6: Build result
    provider_label = _PROVIDER_LABELS.get(target_provider, target_provider)
    # Coerced to a real bool: the raw and/or chain could yield "" or None.
    is_custom_target = bool(
        target_provider == "custom"
        or (
            "openrouter.ai" not in (base_url or "")
            and _is_local_endpoint(base_url)
        )
    )

    return ModelSwitchResult(
        success=True,
        new_model=new_model,
        target_provider=target_provider,
        provider_changed=provider_changed,
        api_key=api_key,
        base_url=base_url,
        persist=bool(validation.get("persist")),
        warning_message=validation.get("message") or "",
        is_custom_target=is_custom_target,
        provider_label=provider_label,
    )


def switch_to_custom_provider() -> CustomAutoResult:
    """Handle bare '/model custom' — resolve endpoint and auto-detect model.

    Returns:
        CustomAutoResult. On success it carries the detected model plus
        the resolved base URL and API key. On failure ``error_message``
        explains why (resolution error, no custom endpoint configured,
        or no model auto-detected). The caller handles persistence and
        output.
    """
    from hermes_cli.runtime_provider import (
        resolve_runtime_provider,
        _auto_detect_local_model,
    )

    try:
        runtime = resolve_runtime_provider(requested="custom")
    except Exception as e:
        return CustomAutoResult(
            success=False,
            error_message=f"Could not resolve custom endpoint: {e}",
        )

    cust_base = runtime.get("base_url", "")
    cust_key = runtime.get("api_key", "")

    # An empty URL or an OpenRouter URL is treated as "no custom
    # endpoint configured".
    if not cust_base or "openrouter.ai" in cust_base:
        return CustomAutoResult(
            success=False,
            error_message=(
                "No custom endpoint configured. "
                "Set model.base_url in config.yaml, or set OPENAI_BASE_URL "
                "in .env, or run: hermes setup → Custom OpenAI-compatible endpoint"
            ),
        )

    detected_model = _auto_detect_local_model(cust_base)
    if not detected_model:
        return CustomAutoResult(
            success=False,
            base_url=cust_base,
            api_key=cust_key,
            error_message=(
                f"Custom endpoint at {cust_base} is reachable but no single "
                f"model was auto-detected. Specify the model explicitly: "
                f"/model custom:"
            ),
        )

    return CustomAutoResult(
        success=True,
        model=detected_model,
        base_url=cust_base,
        api_key=cust_key,
    )