diff --git a/acp_adapter/server.py b/acp_adapter/server.py index 92788988b..64c1e5185 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -383,11 +383,11 @@ class HermesACPAgent(acp.Agent): new_model = args.strip() target_provider = None + current_provider = getattr(state.agent, "provider", None) or "openrouter" # Auto-detect provider for the requested model try: from hermes_cli.models import parse_model_input, detect_provider_for_model - current_provider = getattr(state.agent, "provider", None) or "openrouter" target_provider, new_model = parse_model_input(new_model, current_provider) if target_provider == current_provider: detected = detect_provider_for_model(new_model, current_provider) @@ -401,9 +401,10 @@ class HermesACPAgent(acp.Agent): session_id=state.session_id, cwd=state.cwd, model=new_model, + requested_provider=target_provider or current_provider, ) self.session_manager.save_session(state.session_id) - provider_label = target_provider or getattr(state.agent, "provider", "auto") + provider_label = getattr(state.agent, "provider", None) or target_provider or current_provider logger.info("Session %s: model switched to %s", state.session_id, new_model) return f"Model switched to: {new_model}\nProvider: {provider_label}" @@ -475,10 +476,16 @@ class HermesACPAgent(acp.Agent): state = self.session_manager.get_session(session_id) if state: state.model = model_id + current_provider = getattr(state.agent, "provider", None) + current_base_url = getattr(state.agent, "base_url", None) + current_api_mode = getattr(state.agent, "api_mode", None) state.agent = self.session_manager._make_agent( session_id=session_id, cwd=state.cwd, model=model_id, + requested_provider=current_provider, + base_url=current_base_url, + api_mode=current_api_mode, ) self.session_manager.save_session(session_id) logger.info("Session %s: model switched to %s", session_id, model_id) diff --git a/acp_adapter/session.py b/acp_adapter/session.py index 01b2ee479..629b086f9 100644 --- a/acp_adapter/session.py +++ b/acp_adapter/session.py @@ -270,7 +270,17 @@ class SessionManager: # Ensure model is a plain string (not a MagicMock or other proxy). model_str = str(state.model) if state.model else None - cwd_json = json.dumps({"cwd": state.cwd}) + session_meta = {"cwd": state.cwd} + provider = getattr(state.agent, "provider", None) + base_url = getattr(state.agent, "base_url", None) + api_mode = getattr(state.agent, "api_mode", None) + if isinstance(provider, str) and provider.strip(): + session_meta["provider"] = provider.strip() + if isinstance(base_url, str) and base_url.strip(): + session_meta["base_url"] = base_url.strip() + if isinstance(api_mode, str) and api_mode.strip(): + session_meta["api_mode"] = api_mode.strip() + cwd_json = json.dumps(session_meta) try: # Ensure the session record exists. @@ -331,10 +341,18 @@ class SessionManager: # Extract cwd from model_config. cwd = "." + requested_provider = row.get("billing_provider") + restored_base_url = row.get("billing_base_url") + restored_api_mode = None mc = row.get("model_config") if mc: try: - cwd = json.loads(mc).get("cwd", ".") + meta = json.loads(mc) + if isinstance(meta, dict): + cwd = meta.get("cwd", ".") + requested_provider = meta.get("provider") or requested_provider + restored_base_url = meta.get("base_url") or restored_base_url + restored_api_mode = meta.get("api_mode") or restored_api_mode except (json.JSONDecodeError, TypeError): pass @@ -348,7 +366,14 @@ class SessionManager: history = [] try: - agent = self._make_agent(session_id=session_id, cwd=cwd, model=model) + agent = self._make_agent( + session_id=session_id, + cwd=cwd, + model=model, + requested_provider=requested_provider, + base_url=restored_base_url, + api_mode=restored_api_mode, + ) except Exception: logger.warning("Failed to recreate agent for ACP session %s", session_id, exc_info=True) return None @@ -386,6 +411,9 @@ class SessionManager: session_id: str, cwd: str, model: str | None = None, + requested_provider: str | None = None, + base_url: str | None = None, + api_mode: str | None = None, ): if self._agent_factory is not None: return self._agent_factory() @@ -397,10 +425,10 @@ class SessionManager: config = load_config() model_cfg = config.get("model") default_model = "anthropic/claude-opus-4.6" - requested_provider = None + config_provider = None if isinstance(model_cfg, dict): default_model = str(model_cfg.get("default") or default_model) - requested_provider = model_cfg.get("provider") + config_provider = model_cfg.get("provider") elif isinstance(model_cfg, str) and model_cfg.strip(): default_model = model_cfg.strip() @@ -413,12 +441,12 @@ class SessionManager: } try: - runtime = resolve_runtime_provider(requested=requested_provider) + runtime = resolve_runtime_provider(requested=requested_provider or config_provider) kwargs.update( { "provider": runtime.get("provider"), - "api_mode": runtime.get("api_mode"), - "base_url": runtime.get("base_url"), + "api_mode": api_mode or runtime.get("api_mode"), + "base_url": base_url or runtime.get("base_url"), "api_key": runtime.get("api_key"), "command": runtime.get("command"), "args": list(runtime.get("args") or []), diff --git a/tests/acp/test_server.py b/tests/acp/test_server.py index 341f4b758..5b9d3de62 100644 --- a/tests/acp/test_server.py +++ b/tests/acp/test_server.py @@ -2,6 +2,7 @@ import asyncio import os +from types import SimpleNamespace from unittest.mock import MagicMock, AsyncMock, patch import pytest @@ -23,6 +24,7 @@ from acp.schema import ( ) from acp_adapter.server import HermesACPAgent, HERMES_VERSION from acp_adapter.session import SessionManager +from hermes_state import SessionDB @pytest.fixture() @@ -389,3 +391,46 @@ class TestSlashCommands: resp = await agent.prompt(prompt=prompt, session_id=new_resp.session_id) assert resp.stop_reason == "end_turn" + + def test_model_switch_uses_requested_provider(self, tmp_path, monkeypatch): + """`/model provider:model` should rebuild the ACP agent on that provider.""" + runtime_calls = [] + + def fake_resolve_runtime_provider(requested=None, **kwargs): + runtime_calls.append(requested) + provider = requested or "openrouter" + return { + "provider": provider, + "api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions", + "base_url": f"https://{provider}.example/v1", + "api_key": f"{provider}-key", + "command": None, + "args": [], + } + + def fake_agent(**kwargs): + return SimpleNamespace( + model=kwargs.get("model"), + provider=kwargs.get("provider"), + base_url=kwargs.get("base_url"), + api_mode=kwargs.get("api_mode"), + ) + + monkeypatch.setattr("hermes_cli.config.load_config", lambda: { + "model": {"provider": "openrouter", "default": "openrouter/gpt-5"} + }) + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + fake_resolve_runtime_provider, + ) + manager = SessionManager(db=SessionDB(tmp_path / "state.db")) + + with patch("run_agent.AIAgent", side_effect=fake_agent): + acp_agent = HermesACPAgent(session_manager=manager) + state = manager.create_session(cwd="/tmp") + result = acp_agent._cmd_model("anthropic:claude-sonnet-4-6", state) + + assert "Provider: anthropic" in result + assert state.agent.provider == "anthropic" + assert state.agent.base_url == "https://anthropic.example/v1" + assert runtime_calls[-1] == "anthropic" diff --git a/tests/acp/test_session.py b/tests/acp/test_session.py index 43d9a7229..1a7a9da51 100644 --- a/tests/acp/test_session.py +++ b/tests/acp/test_session.py @@ -1,8 +1,9 @@ """Tests for acp_adapter.session — SessionManager and SessionState.""" import json +from types import SimpleNamespace import pytest -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from acp_adapter.session import SessionManager, SessionState from hermes_state import SessionDB @@ -281,3 +282,50 @@ class TestPersistence: assert len(restored.history) == 2 assert restored.history[0].get("tool_calls") is not None assert restored.history[1].get("tool_call_id") == "tc_1" + + def test_restore_preserves_persisted_provider_snapshot(self, tmp_path, monkeypatch): + """Restored ACP sessions should keep their original runtime provider.""" + runtime_choice = {"provider": "anthropic"} + + def fake_resolve_runtime_provider(requested=None, **kwargs): + provider = requested or runtime_choice["provider"] + return { + "provider": provider, + "api_mode": "anthropic_messages" if provider == "anthropic" else "chat_completions", + "base_url": f"https://{provider}.example/v1", + "api_key": f"{provider}-key", + "command": None, + "args": [], + } + + def fake_agent(**kwargs): + return SimpleNamespace( + model=kwargs.get("model"), + provider=kwargs.get("provider"), + base_url=kwargs.get("base_url"), + api_mode=kwargs.get("api_mode"), + ) + + monkeypatch.setattr("hermes_cli.config.load_config", lambda: { + "model": {"provider": runtime_choice["provider"], "default": "test-model"} + }) + monkeypatch.setattr( + "hermes_cli.runtime_provider.resolve_runtime_provider", + fake_resolve_runtime_provider, + ) + db = SessionDB(tmp_path / "state.db") + + with patch("run_agent.AIAgent", side_effect=fake_agent): + manager = SessionManager(db=db) + state = manager.create_session(cwd="/work") + manager.save_session(state.session_id) + + with manager._lock: + del manager._sessions[state.session_id] + + runtime_choice["provider"] = "openrouter" + restored = manager.get_session(state.session_id) + + assert restored is not None + assert restored.agent.provider == "anthropic" + assert restored.agent.base_url == "https://anthropic.example/v1"