* fix acp adapter session methods * test: stub local command in transcription provider cases --------- Co-authored-by: David Zhang <david.d.zhang@gmail.com>
This commit is contained in:
@@ -74,7 +74,7 @@ def main() -> None:
|
||||
|
||||
agent = HermesACPAgent()
|
||||
try:
|
||||
asyncio.run(acp.run_agent(agent))
|
||||
asyncio.run(acp.run_agent(agent, use_unstable_protocol=True))
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutting down (KeyboardInterrupt)")
|
||||
except Exception:
|
||||
|
||||
@@ -25,6 +25,9 @@ from acp.schema import (
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModelResponse,
|
||||
SetSessionModeResponse,
|
||||
ResourceContentBlock,
|
||||
SessionCapabilities,
|
||||
SessionForkCapabilities,
|
||||
@@ -94,11 +97,14 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
async def initialize(
|
||||
self,
|
||||
protocol_version: int,
|
||||
protocol_version: int | None = None,
|
||||
client_capabilities: ClientCapabilities | None = None,
|
||||
client_info: Implementation | None = None,
|
||||
**kwargs: Any,
|
||||
) -> InitializeResponse:
|
||||
resolved_protocol_version = (
|
||||
protocol_version if isinstance(protocol_version, int) else acp.PROTOCOL_VERSION
|
||||
)
|
||||
provider = detect_provider()
|
||||
auth_methods = None
|
||||
if provider:
|
||||
@@ -111,7 +117,11 @@ class HermesACPAgent(acp.Agent):
|
||||
]
|
||||
|
||||
client_name = client_info.name if client_info else "unknown"
|
||||
logger.info("Initialize from %s (protocol v%s)", client_name, protocol_version)
|
||||
logger.info(
|
||||
"Initialize from %s (protocol v%s)",
|
||||
client_name,
|
||||
resolved_protocol_version,
|
||||
)
|
||||
|
||||
return InitializeResponse(
|
||||
protocol_version=acp.PROTOCOL_VERSION,
|
||||
@@ -471,7 +481,7 @@ class HermesACPAgent(acp.Agent):
|
||||
|
||||
async def set_session_model(
|
||||
self, model_id: str, session_id: str, **kwargs: Any
|
||||
):
|
||||
) -> SetSessionModelResponse | None:
|
||||
"""Switch the model for a session (called by ACP protocol)."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state:
|
||||
@@ -489,4 +499,37 @@ class HermesACPAgent(acp.Agent):
|
||||
)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||
return SetSessionModelResponse()
|
||||
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||
return None
|
||||
|
||||
async def set_session_mode(
|
||||
self, mode_id: str, session_id: str, **kwargs: Any
|
||||
) -> SetSessionModeResponse | None:
|
||||
"""Persist the editor-requested mode so ACP clients do not fail on mode switches."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state is None:
|
||||
logger.warning("Session %s: mode switch requested for missing session", session_id)
|
||||
return None
|
||||
setattr(state, "mode", mode_id)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: mode switched to %s", session_id, mode_id)
|
||||
return SetSessionModeResponse()
|
||||
|
||||
async def set_config_option(
|
||||
self, config_id: str, session_id: str, value: str, **kwargs: Any
|
||||
) -> SetSessionConfigOptionResponse | None:
|
||||
"""Accept ACP config option updates even when Hermes has no typed ACP config surface yet."""
|
||||
state = self.session_manager.get_session(session_id)
|
||||
if state is None:
|
||||
logger.warning("Session %s: config update requested for missing session", session_id)
|
||||
return None
|
||||
|
||||
options = getattr(state, "config_options", None)
|
||||
if not isinstance(options, dict):
|
||||
options = {}
|
||||
options[str(config_id)] = value
|
||||
setattr(state, "config_options", options)
|
||||
self.session_manager.save_session(session_id)
|
||||
logger.info("Session %s: config option %s updated", session_id, config_id)
|
||||
return SetSessionConfigOptionResponse(config_options=[])
|
||||
|
||||
20
tests/acp/test_entry.py
Normal file
20
tests/acp/test_entry.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Tests for acp_adapter.entry startup wiring."""
|
||||
|
||||
import acp
|
||||
|
||||
from acp_adapter import entry
|
||||
|
||||
|
||||
def test_main_enables_unstable_protocol(monkeypatch):
|
||||
calls = {}
|
||||
|
||||
async def fake_run_agent(agent, **kwargs):
|
||||
calls["kwargs"] = kwargs
|
||||
|
||||
monkeypatch.setattr(entry, "_setup_logging", lambda: None)
|
||||
monkeypatch.setattr(entry, "_load_env", lambda: None)
|
||||
monkeypatch.setattr(acp, "run_agent", fake_run_agent)
|
||||
|
||||
entry.main()
|
||||
|
||||
assert calls["kwargs"]["use_unstable_protocol"] is True
|
||||
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
import acp
|
||||
from acp.agent.router import build_agent_router
|
||||
from acp.schema import (
|
||||
AgentCapabilities,
|
||||
AuthenticateResponse,
|
||||
@@ -18,6 +19,8 @@ from acp.schema import (
|
||||
NewSessionResponse,
|
||||
PromptResponse,
|
||||
ResumeSessionResponse,
|
||||
SetSessionConfigOptionResponse,
|
||||
SetSessionModeResponse,
|
||||
SessionInfo,
|
||||
TextContentBlock,
|
||||
Usage,
|
||||
@@ -168,6 +171,74 @@ class TestListAndFork:
|
||||
assert fork_resp.session_id != new_resp.session_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# session configuration / model routing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSessionConfiguration:
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_session_mode_returns_response(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
resp = await agent.set_session_mode(mode_id="chat", session_id=new_resp.session_id)
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, SetSessionModeResponse)
|
||||
assert getattr(state, "mode", None) == "chat"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_config_option_returns_response(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
resp = await agent.set_config_option(
|
||||
config_id="approval_mode",
|
||||
session_id=new_resp.session_id,
|
||||
value="auto",
|
||||
)
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
assert isinstance(resp, SetSessionConfigOptionResponse)
|
||||
assert getattr(state, "config_options", {}) == {"approval_mode": "auto"}
|
||||
assert resp.config_options == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_accepts_stable_session_config_methods(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
router = build_agent_router(agent)
|
||||
|
||||
mode_result = await router(
|
||||
"session/set_mode",
|
||||
{"modeId": "chat", "sessionId": new_resp.session_id},
|
||||
False,
|
||||
)
|
||||
config_result = await router(
|
||||
"session/set_config_option",
|
||||
{
|
||||
"configId": "approval_mode",
|
||||
"sessionId": new_resp.session_id,
|
||||
"value": "auto",
|
||||
},
|
||||
False,
|
||||
)
|
||||
|
||||
assert mode_result == {}
|
||||
assert config_result == {"configOptions": []}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_accepts_unstable_model_switch_when_enabled(self, agent):
|
||||
new_resp = await agent.new_session(cwd="/tmp")
|
||||
router = build_agent_router(agent, use_unstable_protocol=True)
|
||||
|
||||
result = await router(
|
||||
"session/set_model",
|
||||
{"modelId": "gpt-5.4", "sessionId": new_resp.session_id},
|
||||
False,
|
||||
)
|
||||
state = agent.session_manager.get_session(new_resp.session_id)
|
||||
|
||||
assert result == {}
|
||||
assert state.model == "gpt-5.4"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# prompt
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -96,6 +96,7 @@ class TestGetProviderFallbackPriority:
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
assert _get_provider({}) == "groq"
|
||||
@@ -130,9 +131,10 @@ class TestExplicitProviderRespected:
|
||||
def test_explicit_local_no_fallback_to_openai(self, monkeypatch):
|
||||
"""GH-1774: provider=local must not silently fall back to openai
|
||||
even when an OpenAI API key is set."""
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key-here")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "***")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
result = _get_provider({"provider": "local"})
|
||||
@@ -141,6 +143,7 @@ class TestExplicitProviderRespected:
|
||||
def test_explicit_local_no_fallback_to_groq(self, monkeypatch):
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
result = _get_provider({"provider": "local"})
|
||||
@@ -181,6 +184,7 @@ class TestExplicitProviderRespected:
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
# Empty dict = no explicit provider, uses DEFAULT_PROVIDER auto-detect
|
||||
@@ -191,6 +195,7 @@ class TestExplicitProviderRespected:
|
||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||
from tools.transcription_tools import _get_provider
|
||||
result = _get_provider({})
|
||||
|
||||
Reference in New Issue
Block a user