From 4764e06fdefa65cb75961da64da1ed63c99a0a2a Mon Sep 17 00:00:00 2001 From: Teknium <127238744+teknium1@users.noreply.github.com> Date: Sat, 28 Mar 2026 23:45:53 -0700 Subject: [PATCH] fix(acp): complete session management surface for editor clients (salvage #3501) (#3675) * fix acp adapter session methods * test: stub local command in transcription provider cases --------- Co-authored-by: David Zhang --- acp_adapter/entry.py | 2 +- acp_adapter/server.py | 49 +++++++++++++++-- tests/acp/test_entry.py | 20 +++++++ tests/acp/test_server.py | 71 +++++++++++++++++++++++++ tests/tools/test_transcription_tools.py | 7 ++- 5 files changed, 144 insertions(+), 5 deletions(-) create mode 100644 tests/acp/test_entry.py diff --git a/acp_adapter/entry.py b/acp_adapter/entry.py index fe13ce70..02e44c15 100644 --- a/acp_adapter/entry.py +++ b/acp_adapter/entry.py @@ -74,7 +74,7 @@ def main() -> None: agent = HermesACPAgent() try: - asyncio.run(acp.run_agent(agent)) + asyncio.run(acp.run_agent(agent, use_unstable_protocol=True)) except KeyboardInterrupt: logger.info("Shutting down (KeyboardInterrupt)") except Exception: diff --git a/acp_adapter/server.py b/acp_adapter/server.py index 64c1e518..a5780fb6 100644 --- a/acp_adapter/server.py +++ b/acp_adapter/server.py @@ -25,6 +25,9 @@ from acp.schema import ( NewSessionResponse, PromptResponse, ResumeSessionResponse, + SetSessionConfigOptionResponse, + SetSessionModelResponse, + SetSessionModeResponse, ResourceContentBlock, SessionCapabilities, SessionForkCapabilities, @@ -94,11 +97,14 @@ class HermesACPAgent(acp.Agent): async def initialize( self, - protocol_version: int, + protocol_version: int | None = None, client_capabilities: ClientCapabilities | None = None, client_info: Implementation | None = None, **kwargs: Any, ) -> InitializeResponse: + resolved_protocol_version = ( + protocol_version if isinstance(protocol_version, int) else acp.PROTOCOL_VERSION + ) provider = detect_provider() auth_methods = None if provider: @@ -111,7 +117,11 @@ class HermesACPAgent(acp.Agent): ] client_name = client_info.name if client_info else "unknown" - logger.info("Initialize from %s (protocol v%s)", client_name, protocol_version) + logger.info( + "Initialize from %s (protocol v%s)", + client_name, + resolved_protocol_version, + ) return InitializeResponse( protocol_version=acp.PROTOCOL_VERSION, @@ -471,7 +481,7 @@ class HermesACPAgent(acp.Agent): async def set_session_model( self, model_id: str, session_id: str, **kwargs: Any - ): + ) -> SetSessionModelResponse | None: """Switch the model for a session (called by ACP protocol).""" state = self.session_manager.get_session(session_id) if state: @@ -489,4 +499,37 @@ class HermesACPAgent(acp.Agent): ) self.session_manager.save_session(session_id) logger.info("Session %s: model switched to %s", session_id, model_id) + return SetSessionModelResponse() + logger.warning("Session %s: model switch requested for missing session", session_id) return None + + async def set_session_mode( + self, mode_id: str, session_id: str, **kwargs: Any + ) -> SetSessionModeResponse | None: + """Persist the editor-requested mode so ACP clients do not fail on mode switches.""" + state = self.session_manager.get_session(session_id) + if state is None: + logger.warning("Session %s: mode switch requested for missing session", session_id) + return None + setattr(state, "mode", mode_id) + self.session_manager.save_session(session_id) + logger.info("Session %s: mode switched to %s", session_id, mode_id) + return SetSessionModeResponse() + + async def set_config_option( + self, config_id: str, session_id: str, value: str, **kwargs: Any + ) -> SetSessionConfigOptionResponse | None: + """Accept ACP config option updates even when Hermes has no typed ACP config surface yet.""" + state = self.session_manager.get_session(session_id) + if state is None: + logger.warning("Session %s: config update requested for missing session", session_id) + return None + + options = getattr(state, "config_options", None) + if not isinstance(options, dict): + options = {} + options[str(config_id)] = value + setattr(state, "config_options", options) + self.session_manager.save_session(session_id) + logger.info("Session %s: config option %s updated", session_id, config_id) + return SetSessionConfigOptionResponse(config_options=[]) diff --git a/tests/acp/test_entry.py b/tests/acp/test_entry.py new file mode 100644 index 00000000..760522c3 --- /dev/null +++ b/tests/acp/test_entry.py @@ -0,0 +1,20 @@ +"""Tests for acp_adapter.entry startup wiring.""" + +import acp + +from acp_adapter import entry + + +def test_main_enables_unstable_protocol(monkeypatch): + calls = {} + + async def fake_run_agent(agent, **kwargs): + calls["kwargs"] = kwargs + + monkeypatch.setattr(entry, "_setup_logging", lambda: None) + monkeypatch.setattr(entry, "_load_env", lambda: None) + monkeypatch.setattr(acp, "run_agent", fake_run_agent) + + entry.main() + + assert calls["kwargs"]["use_unstable_protocol"] is True diff --git a/tests/acp/test_server.py b/tests/acp/test_server.py index 5b9d3de6..fc6d53dd 100644 --- a/tests/acp/test_server.py +++ b/tests/acp/test_server.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, AsyncMock, patch import pytest import acp +from acp.agent.router import build_agent_router from acp.schema import ( AgentCapabilities, AuthenticateResponse, @@ -18,6 +19,8 @@ from acp.schema import ( NewSessionResponse, PromptResponse, ResumeSessionResponse, + SetSessionConfigOptionResponse, + SetSessionModeResponse, SessionInfo, TextContentBlock, Usage, @@ -168,6 +171,74 @@ class TestListAndFork: assert fork_resp.session_id != new_resp.session_id +# --------------------------------------------------------------------------- +# session configuration / model routing +# --------------------------------------------------------------------------- + + +class TestSessionConfiguration: + @pytest.mark.asyncio + async def test_set_session_mode_returns_response(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + resp = await agent.set_session_mode(mode_id="chat", session_id=new_resp.session_id) + state = agent.session_manager.get_session(new_resp.session_id) + + assert isinstance(resp, SetSessionModeResponse) + assert getattr(state, "mode", None) == "chat" + + @pytest.mark.asyncio + async def test_set_config_option_returns_response(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + resp = await agent.set_config_option( + config_id="approval_mode", + session_id=new_resp.session_id, + value="auto", + ) + state = agent.session_manager.get_session(new_resp.session_id) + + assert isinstance(resp, SetSessionConfigOptionResponse) + assert getattr(state, "config_options", {}) == {"approval_mode": "auto"} + assert resp.config_options == [] + + @pytest.mark.asyncio + async def test_router_accepts_stable_session_config_methods(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + router = build_agent_router(agent) + + mode_result = await router( + "session/set_mode", + {"modeId": "chat", "sessionId": new_resp.session_id}, + False, + ) + config_result = await router( + "session/set_config_option", + { + "configId": "approval_mode", + "sessionId": new_resp.session_id, + "value": "auto", + }, + False, + ) + + assert mode_result == {} + assert config_result == {"configOptions": []} + + @pytest.mark.asyncio + async def test_router_accepts_unstable_model_switch_when_enabled(self, agent): + new_resp = await agent.new_session(cwd="/tmp") + router = build_agent_router(agent, use_unstable_protocol=True) + + result = await router( + "session/set_model", + {"modelId": "gpt-5.4", "sessionId": new_resp.session_id}, + False, + ) + state = agent.session_manager.get_session(new_resp.session_id) + + assert result == {} + assert state.model == "gpt-5.4" + + # --------------------------------------------------------------------------- # prompt # --------------------------------------------------------------------------- diff --git a/tests/tools/test_transcription_tools.py b/tests/tools/test_transcription_tools.py index b5c9f977..1cdf33ec 100644 --- a/tests/tools/test_transcription_tools.py +++ b/tests/tools/test_transcription_tools.py @@ -96,6 +96,7 @@ class TestGetProviderFallbackPriority: monkeypatch.setenv("GROQ_API_KEY", "gsk-test") monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test") with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider assert _get_provider({}) == "groq" @@ -130,9 +131,10 @@ class TestExplicitProviderRespected: def test_explicit_local_no_fallback_to_openai(self, monkeypatch): """GH-1774: provider=local must not silently fall back to openai even when an OpenAI API key is set.""" - monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key-here") + monkeypatch.setenv("OPENAI_API_KEY", "***") monkeypatch.delenv("GROQ_API_KEY", raising=False) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider result = _get_provider({"provider": "local"}) @@ -141,6 +143,7 @@ class TestExplicitProviderRespected: def test_explicit_local_no_fallback_to_groq(self, monkeypatch): monkeypatch.setenv("GROQ_API_KEY", "gsk-test") with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider result = _get_provider({"provider": "local"}) @@ -181,6 +184,7 @@ class TestExplicitProviderRespected: monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key") monkeypatch.delenv("GROQ_API_KEY", raising=False) with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider # Empty dict = no explicit provider, uses DEFAULT_PROVIDER auto-detect @@ -191,6 +195,7 @@ class TestExplicitProviderRespected: monkeypatch.setenv("GROQ_API_KEY", "gsk-test") monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key") with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \ + patch("tools.transcription_tools._has_local_command", return_value=False), \ patch("tools.transcription_tools._HAS_OPENAI", True): from tools.transcription_tools import _get_provider result = _get_provider({})