diff --git a/cli.py b/cli.py index a49176be7..7baf0365c 100755 --- a/cli.py +++ b/cli.py @@ -2061,29 +2061,43 @@ class HermesCLI: # Use original case so model names like "Anthropic/Claude-Opus-4" are preserved parts = cmd_original.split(maxsplit=1) if len(parts) > 1: - new_model = parts[1].strip() - from hermes_cli.auth import resolve_provider - from hermes_cli.models import validate_requested_model + from hermes_cli.models import ( + parse_model_input, + validate_requested_model, + _PROVIDER_LABELS, + ) - try: - provider_for_validation = resolve_provider( - self.requested_provider, - explicit_api_key=self._explicit_api_key, - explicit_base_url=self._explicit_base_url, - ) - except Exception: - provider_for_validation = self.provider or self.requested_provider + raw_input = parts[1].strip() + + # Parse provider:model syntax (e.g. "openrouter:anthropic/claude-sonnet-4.5") + current_provider = self.provider or self.requested_provider or "openrouter" + target_provider, new_model = parse_model_input(raw_input, current_provider) + provider_changed = target_provider != current_provider + + # If provider is changing, re-resolve credentials for the new provider + api_key_for_probe = self.api_key + base_url_for_probe = self.base_url + if provider_changed: + try: + from hermes_cli.runtime_provider import resolve_runtime_provider + runtime = resolve_runtime_provider(requested=target_provider) + api_key_for_probe = runtime.get("api_key", "") + base_url_for_probe = runtime.get("base_url", "") + except Exception as e: + provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) + print(f"(>_<) Could not resolve credentials for provider '{provider_label}': {e}") + print(f"(^_^) Current model unchanged: {self.model}") + return True try: validation = validate_requested_model( new_model, - provider_for_validation, - api_key=self.api_key, - base_url=self.base_url, + target_provider, + api_key=api_key_for_probe, + base_url=base_url_for_probe, ) except Exception: - # Validation itself failed — fall back to old behavior (accept + save) validation = {"accepted": True, "persist": True, "recognized": False, "message": None} if not validation.get("accepted"): @@ -2093,20 +2107,49 @@ class HermesCLI: self.model = new_model self.agent = None # Force re-init + if provider_changed: + self.requested_provider = target_provider + self.provider = target_provider + self.api_key = api_key_for_probe + self.base_url = base_url_for_probe + + provider_label = _PROVIDER_LABELS.get(target_provider, target_provider) + provider_note = f" [provider: {provider_label}]" if provider_changed else "" + if validation.get("persist"): - if save_config_value("model.default", new_model): - print(f"(^_^)b Model changed to: {new_model} (saved to config)") + saved_model = save_config_value("model.default", new_model) + if provider_changed: + save_config_value("model.provider", target_provider) + if saved_model: + print(f"(^_^)b Model changed to: {new_model}{provider_note} (saved to config)") else: - print(f"(^_^) Model changed to: {new_model} (session only)") + print(f"(^_^) Model changed to: {new_model}{provider_note} (session only)") else: - print(f"(^_^) Model changed to: {new_model} (session only)") + print(f"(^_^) Model changed to: {new_model}{provider_note} (session only)") message = validation.get("message") if message: print(f" Warning: {message}") else: - print(f"Current model: {self.model}") - print(" Usage: /model to change") + from hermes_cli.models import curated_models_for_provider, _PROVIDER_LABELS + provider_label = _PROVIDER_LABELS.get( + self.provider or "openrouter", + self.provider or "openrouter", + ) + print(f"\n Current model: {self.model}") + print(f" Current provider: {provider_label}") + print() + curated = curated_models_for_provider(self.provider) + if curated: + print(f" Available models ({provider_label}):") + for mid, desc in curated: + marker = " ←" if mid == self.model else "" + label = f" {desc}" if desc else "" + print(f" {mid}{label}{marker}") + print() + print(" Usage: /model ") + print(" /model provider:model-name (to switch provider)") + print(" Example: /model openrouter:anthropic/claude-sonnet-4.5") elif cmd_lower.startswith("/prompt"): # Use original case so prompt text isn't lowercased self._handle_prompt_command(cmd_original) diff --git a/hermes_cli/models.py b/hermes_cli/models.py index cbcfc405f..c12dec31d 100644 --- a/hermes_cli/models.py +++ b/hermes_cli/models.py @@ -91,6 +91,38 @@ def menu_labels() -> list[str]: return labels +def parse_model_input(raw: str, current_provider: str) -> tuple[str, str]: + """Parse ``/model`` input into ``(provider, model)``. + + Supports ``provider:model`` syntax to switch providers at runtime:: + + openrouter:anthropic/claude-sonnet-4.5 → ("openrouter", "anthropic/claude-sonnet-4.5") + nous:hermes-3 → ("nous", "hermes-3") + anthropic/claude-sonnet-4.5 → (current_provider, "anthropic/claude-sonnet-4.5") + gpt-5.4 → (current_provider, "gpt-5.4") + + Returns ``(provider, model)`` where *provider* is either the explicit + provider from the input or *current_provider* if none was specified. + """ + stripped = raw.strip() + colon = stripped.find(":") + if colon > 0: + provider_part = stripped[:colon].strip().lower() + model_part = stripped[colon + 1:].strip() + if provider_part and model_part: + return (normalize_provider(provider_part), model_part) + return (current_provider, stripped) + + +def curated_models_for_provider(provider: Optional[str]) -> list[tuple[str, str]]: + """Return ``(model_id, description)`` tuples for a provider's curated list.""" + normalized = normalize_provider(provider) + if normalized == "openrouter": + return list(OPENROUTER_MODELS) + models = _PROVIDER_MODELS.get(normalized, []) + return [(m, "") for m in models] + + def normalize_provider(provider: Optional[str]) -> str: """Normalize provider aliases to Hermes' canonical provider ids.""" normalized = (provider or "openrouter").strip().lower() diff --git a/tests/hermes_cli/test_model_validation.py b/tests/hermes_cli/test_model_validation.py index d473a411c..36ef37d18 100644 --- a/tests/hermes_cli/test_model_validation.py +++ b/tests/hermes_cli/test_model_validation.py @@ -3,8 +3,10 @@ from unittest.mock import patch from hermes_cli.models import ( + curated_models_for_provider, fetch_api_models, normalize_provider, + parse_model_input, provider_model_ids, validate_requested_model, ) @@ -12,7 +14,6 @@ from hermes_cli.models import ( # -- helpers ----------------------------------------------------------------- -# Simulated API model list for mocking fetch_api_models FAKE_API_MODELS = [ "anthropic/claude-opus-4.6", "anthropic/claude-sonnet-4.5", @@ -28,6 +29,61 @@ def _validate(model, provider="openrouter", api_models=FAKE_API_MODELS, **kw): return validate_requested_model(model, provider, **kw) +# -- parse_model_input ------------------------------------------------------- + +class TestParseModelInput: + def test_plain_model_keeps_current_provider(self): + provider, model = parse_model_input("anthropic/claude-sonnet-4.5", "openrouter") + assert provider == "openrouter" + assert model == "anthropic/claude-sonnet-4.5" + + def test_provider_colon_model_switches_provider(self): + provider, model = parse_model_input("openrouter:anthropic/claude-sonnet-4.5", "nous") + assert provider == "openrouter" + assert model == "anthropic/claude-sonnet-4.5" + + def test_provider_alias_resolved(self): + provider, model = parse_model_input("glm:glm-5", "openrouter") + assert provider == "zai" + assert model == "glm-5" + + def test_no_slash_no_colon_keeps_provider(self): + provider, model = parse_model_input("gpt-5.4", "openrouter") + assert provider == "openrouter" + assert model == "gpt-5.4" + + def test_nous_provider_switch(self): + provider, model = parse_model_input("nous:hermes-3", "openrouter") + assert provider == "nous" + assert model == "hermes-3" + + def test_empty_model_after_colon_keeps_current(self): + provider, model = parse_model_input("openrouter:", "nous") + assert provider == "nous" + assert model == "openrouter:" + + def test_colon_at_start_keeps_current(self): + provider, model = parse_model_input(":something", "openrouter") + assert provider == "openrouter" + assert model == ":something" + + +# -- curated_models_for_provider --------------------------------------------- + +class TestCuratedModelsForProvider: + def test_openrouter_returns_curated_list(self): + models = curated_models_for_provider("openrouter") + assert len(models) > 0 + assert any("claude" in m[0] for m in models) + + def test_zai_returns_glm_models(self): + models = curated_models_for_provider("zai") + assert any("glm" in m[0] for m in models) + + def test_unknown_provider_returns_empty(self): + assert curated_models_for_provider("totally-unknown") == [] + + # -- normalize_provider ------------------------------------------------------ class TestNormalizeProvider: @@ -37,21 +93,11 @@ class TestNormalizeProvider: def test_known_aliases(self): assert normalize_provider("glm") == "zai" - assert normalize_provider("z-ai") == "zai" - assert normalize_provider("z.ai") == "zai" - assert normalize_provider("zhipu") == "zai" assert normalize_provider("kimi") == "kimi-coding" assert normalize_provider("moonshot") == "kimi-coding" - assert normalize_provider("minimax-china") == "minimax-cn" - - def test_canonical_ids_pass_through(self): - assert normalize_provider("openrouter") == "openrouter" - assert normalize_provider("nous") == "nous" - assert normalize_provider("openai-codex") == "openai-codex" def test_case_insensitive(self): assert normalize_provider("OpenRouter") == "openrouter" - assert normalize_provider("GLM") == "zai" # -- provider_model_ids ------------------------------------------------------ @@ -66,11 +112,7 @@ class TestProviderModelIds: assert provider_model_ids("some-unknown-provider") == [] def test_zai_returns_glm_models(self): - ids = provider_model_ids("zai") - assert "glm-5" in ids - - def test_alias_resolves_correctly(self): - assert provider_model_ids("glm") == provider_model_ids("zai") + assert "glm-5" in provider_model_ids("zai") # -- fetch_api_models -------------------------------------------------------- @@ -78,14 +120,13 @@ class TestProviderModelIds: class TestFetchApiModels: def test_returns_none_when_no_base_url(self): assert fetch_api_models("key", None) is None - assert fetch_api_models("key", "") is None def test_returns_none_on_network_error(self): with patch("hermes_cli.models.urllib.request.urlopen", side_effect=Exception("timeout")): assert fetch_api_models("key", "https://example.com/v1") is None -# -- validate_requested_model — format checks (no API needed) ---------------- +# -- validate — format checks ----------------------------------------------- class TestValidateFormatChecks: def test_empty_model_rejected(self): @@ -96,15 +137,12 @@ class TestValidateFormatChecks: def test_whitespace_only_rejected(self): result = _validate(" ") assert result["accepted"] is False - assert "empty" in result["message"] def test_model_with_spaces_rejected(self): result = _validate("anthropic/ claude-opus") assert result["accepted"] is False - assert "spaces" in result["message"].lower() def test_no_slash_model_still_probes_api(self): - """Models without '/' should still be checked via API (not all providers need it).""" result = _validate("gpt-5.4", api_models=["gpt-5.4", "gpt-5.4-pro"]) assert result["accepted"] is True assert result["persist"] is True @@ -112,80 +150,60 @@ class TestValidateFormatChecks: def test_no_slash_model_rejected_if_not_in_api(self): result = _validate("gpt-5.4", api_models=["openai/gpt-5.4"]) assert result["accepted"] is False - assert "not a valid model" in result["message"] -# -- validate_requested_model — API probe found model ------------------------ +# -- validate — API found ---------------------------------------------------- class TestValidateApiFound: - def test_model_found_in_api_is_accepted_and_persisted(self): + def test_model_found_in_api(self): result = _validate("anthropic/claude-opus-4.6") assert result["accepted"] is True assert result["persist"] is True assert result["recognized"] is True - assert result["message"] is None - def test_model_found_in_api_for_custom_endpoint(self): + def test_model_found_for_custom_endpoint(self): result = _validate( - "my-model", - provider="openrouter", - api_models=["my-model", "other-model"], - base_url="http://localhost:11434/v1", + "my-model", provider="openrouter", + api_models=["my-model"], base_url="http://localhost:11434/v1", ) assert result["accepted"] is True assert result["persist"] is True -# -- validate_requested_model — API probe model not found -------------------- +# -- validate — API not found ------------------------------------------------ class TestValidateApiNotFound: - def test_model_not_in_api_is_rejected(self): + def test_model_not_in_api_rejected(self): result = _validate("anthropic/claude-nonexistent") assert result["accepted"] is False - assert result["persist"] is False assert "not a valid model" in result["message"] def test_rejection_includes_suggestions(self): - result = _validate("anthropic/claude-opus-4.5") # close to claude-opus-4.6 + result = _validate("anthropic/claude-opus-4.5") assert result["accepted"] is False assert "Did you mean" in result["message"] - def test_completely_wrong_model_rejected(self): - result = _validate("totally/fake-model-xyz") - assert result["accepted"] is False - assert "not a valid model" in result["message"] - -# -- validate_requested_model — API unreachable (fallback) ------------------- +# -- validate — API unreachable (fallback) ----------------------------------- class TestValidateApiFallback: def test_known_catalog_model_accepted_when_api_down(self): - """If API is unreachable, fall back to hardcoded catalog.""" result = _validate("anthropic/claude-opus-4.6", api_models=None) assert result["accepted"] is True assert result["persist"] is True - assert result["recognized"] is True - def test_unknown_model_is_session_only_when_api_down(self): + def test_unknown_model_session_only_when_api_down(self): result = _validate("anthropic/claude-next-gen", api_models=None) assert result["accepted"] is True assert result["persist"] is False - assert "Could not validate" in result["message"] assert "session only" in result["message"].lower() def test_zai_known_model_accepted_when_api_down(self): result = _validate("glm-5", provider="zai", api_models=None) assert result["accepted"] is True assert result["persist"] is True - assert result["recognized"] is True - - def test_zai_unknown_model_session_only_when_api_down(self): - result = _validate("glm-99", provider="zai", api_models=None) - assert result["accepted"] is True - assert result["persist"] is False def test_unknown_provider_session_only_when_api_down(self): result = _validate("some-model", provider="totally-unknown", api_models=None) assert result["accepted"] is True assert result["persist"] is False - assert result["message"] is not None diff --git a/tests/test_cli_model_command.py b/tests/test_cli_model_command.py index a43b96379..13c4f0f22 100644 --- a/tests/test_cli_model_command.py +++ b/tests/test_cli_model_command.py @@ -1,6 +1,6 @@ """Regression tests for the `/model` slash command in the interactive CLI.""" -from unittest.mock import patch +from unittest.mock import patch, MagicMock from cli import HermesCLI @@ -21,8 +21,7 @@ class TestModelCommand: def test_valid_model_from_api_saved_to_config(self, capsys): cli_obj = self._make_cli() - with patch("hermes_cli.auth.resolve_provider", return_value="openrouter"), \ - patch("hermes_cli.models.fetch_api_models", + with patch("hermes_cli.models.fetch_api_models", return_value=["anthropic/claude-sonnet-4.5", "openai/gpt-5.4"]), \ patch("cli.save_config_value", return_value=True) as save_mock: cli_obj.process_command("/model anthropic/claude-sonnet-4.5") @@ -30,60 +29,51 @@ class TestModelCommand: output = capsys.readouterr().out assert "saved to config" in output assert cli_obj.model == "anthropic/claude-sonnet-4.5" - assert cli_obj.agent is None save_mock.assert_called_once_with("model.default", "anthropic/claude-sonnet-4.5") def test_invalid_model_from_api_is_rejected(self, capsys): cli_obj = self._make_cli() - with patch("hermes_cli.auth.resolve_provider", return_value="openrouter"), \ - patch("hermes_cli.models.fetch_api_models", + with patch("hermes_cli.models.fetch_api_models", return_value=["anthropic/claude-opus-4.6"]), \ patch("cli.save_config_value") as save_mock: cli_obj.process_command("/model anthropic/fake-model") output = capsys.readouterr().out assert "not a valid model" in output - assert cli_obj.model == "anthropic/claude-opus-4.6" # unchanged - assert cli_obj.agent is not None # not reset + assert cli_obj.model == "anthropic/claude-opus-4.6" save_mock.assert_not_called() def test_model_when_api_unreachable_falls_back_session_only(self, capsys): cli_obj = self._make_cli() - with patch("hermes_cli.auth.resolve_provider", return_value="openrouter"), \ - patch("hermes_cli.models.fetch_api_models", return_value=None), \ + with patch("hermes_cli.models.fetch_api_models", return_value=None), \ patch("cli.save_config_value") as save_mock: cli_obj.process_command("/model anthropic/claude-sonnet-next") output = capsys.readouterr().out assert "session only" in output assert cli_obj.model == "anthropic/claude-sonnet-next" - assert cli_obj.agent is None save_mock.assert_not_called() def test_no_slash_model_probes_api_and_rejects(self, capsys): - """Model without '/' is still probed via API — not rejected on format alone.""" cli_obj = self._make_cli() - with patch("hermes_cli.auth.resolve_provider", return_value="openrouter"), \ - patch("hermes_cli.models.fetch_api_models", + with patch("hermes_cli.models.fetch_api_models", return_value=["openai/gpt-5.4"]) as fetch_mock, \ patch("cli.save_config_value") as save_mock: cli_obj.process_command("/model gpt-5.4") output = capsys.readouterr().out assert "not a valid model" in output - assert cli_obj.model == "anthropic/claude-opus-4.6" # unchanged - fetch_mock.assert_called_once() # API was probed + assert cli_obj.model == "anthropic/claude-opus-4.6" + fetch_mock.assert_called_once() save_mock.assert_not_called() def test_validation_crash_falls_back_to_save(self, capsys): - """If validate_requested_model throws, /model should still work (old behavior).""" cli_obj = self._make_cli() - with patch("hermes_cli.auth.resolve_provider", return_value="openrouter"), \ - patch("hermes_cli.models.validate_requested_model", + with patch("hermes_cli.models.validate_requested_model", side_effect=RuntimeError("boom")), \ patch("cli.save_config_value", return_value=True) as save_mock: cli_obj.process_command("/model anthropic/claude-sonnet-4.5") @@ -99,4 +89,42 @@ class TestModelCommand: output = capsys.readouterr().out assert "anthropic/claude-opus-4.6" in output - assert "Usage" in output + assert "OpenRouter" in output + assert "Available models" in output + assert "provider:model-name" in output + + # -- provider switching tests ------------------------------------------- + + def test_provider_colon_model_switches_provider(self, capsys): + cli_obj = self._make_cli() + + with patch("hermes_cli.runtime_provider.resolve_runtime_provider", return_value={ + "provider": "zai", + "api_key": "zai-key", + "base_url": "https://api.z.ai/api/paas/v4", + }), \ + patch("hermes_cli.models.fetch_api_models", + return_value=["glm-5", "glm-4.7"]), \ + patch("cli.save_config_value", return_value=True) as save_mock: + cli_obj.process_command("/model zai:glm-5") + + output = capsys.readouterr().out + assert "glm-5" in output + assert "provider:" in output.lower() or "Z.AI" in output + assert cli_obj.model == "glm-5" + assert cli_obj.provider == "zai" + assert cli_obj.base_url == "https://api.z.ai/api/paas/v4" + # Both model and provider should be saved + assert save_mock.call_count == 2 + + def test_provider_switch_fails_on_bad_credentials(self, capsys): + cli_obj = self._make_cli() + + with patch("hermes_cli.runtime_provider.resolve_runtime_provider", + side_effect=Exception("No API key found")): + cli_obj.process_command("/model nous:hermes-3") + + output = capsys.readouterr().out + assert "Could not resolve credentials" in output + assert cli_obj.model == "anthropic/claude-opus-4.6" # unchanged + assert cli_obj.provider == "openrouter" # unchanged