* fix acp adapter session methods * test: stub local command in transcription provider cases --------- Co-authored-by: David Zhang <david.d.zhang@gmail.com>
This commit is contained in:
@@ -74,7 +74,7 @@ def main() -> None:
|
|||||||
|
|
||||||
agent = HermesACPAgent()
|
agent = HermesACPAgent()
|
||||||
try:
|
try:
|
||||||
asyncio.run(acp.run_agent(agent))
|
asyncio.run(acp.run_agent(agent, use_unstable_protocol=True))
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Shutting down (KeyboardInterrupt)")
|
logger.info("Shutting down (KeyboardInterrupt)")
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -25,6 +25,9 @@ from acp.schema import (
|
|||||||
NewSessionResponse,
|
NewSessionResponse,
|
||||||
PromptResponse,
|
PromptResponse,
|
||||||
ResumeSessionResponse,
|
ResumeSessionResponse,
|
||||||
|
SetSessionConfigOptionResponse,
|
||||||
|
SetSessionModelResponse,
|
||||||
|
SetSessionModeResponse,
|
||||||
ResourceContentBlock,
|
ResourceContentBlock,
|
||||||
SessionCapabilities,
|
SessionCapabilities,
|
||||||
SessionForkCapabilities,
|
SessionForkCapabilities,
|
||||||
@@ -94,11 +97,14 @@ class HermesACPAgent(acp.Agent):
|
|||||||
|
|
||||||
async def initialize(
|
async def initialize(
|
||||||
self,
|
self,
|
||||||
protocol_version: int,
|
protocol_version: int | None = None,
|
||||||
client_capabilities: ClientCapabilities | None = None,
|
client_capabilities: ClientCapabilities | None = None,
|
||||||
client_info: Implementation | None = None,
|
client_info: Implementation | None = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> InitializeResponse:
|
) -> InitializeResponse:
|
||||||
|
resolved_protocol_version = (
|
||||||
|
protocol_version if isinstance(protocol_version, int) else acp.PROTOCOL_VERSION
|
||||||
|
)
|
||||||
provider = detect_provider()
|
provider = detect_provider()
|
||||||
auth_methods = None
|
auth_methods = None
|
||||||
if provider:
|
if provider:
|
||||||
@@ -111,7 +117,11 @@ class HermesACPAgent(acp.Agent):
|
|||||||
]
|
]
|
||||||
|
|
||||||
client_name = client_info.name if client_info else "unknown"
|
client_name = client_info.name if client_info else "unknown"
|
||||||
logger.info("Initialize from %s (protocol v%s)", client_name, protocol_version)
|
logger.info(
|
||||||
|
"Initialize from %s (protocol v%s)",
|
||||||
|
client_name,
|
||||||
|
resolved_protocol_version,
|
||||||
|
)
|
||||||
|
|
||||||
return InitializeResponse(
|
return InitializeResponse(
|
||||||
protocol_version=acp.PROTOCOL_VERSION,
|
protocol_version=acp.PROTOCOL_VERSION,
|
||||||
@@ -471,7 +481,7 @@ class HermesACPAgent(acp.Agent):
|
|||||||
|
|
||||||
async def set_session_model(
|
async def set_session_model(
|
||||||
self, model_id: str, session_id: str, **kwargs: Any
|
self, model_id: str, session_id: str, **kwargs: Any
|
||||||
):
|
) -> SetSessionModelResponse | None:
|
||||||
"""Switch the model for a session (called by ACP protocol)."""
|
"""Switch the model for a session (called by ACP protocol)."""
|
||||||
state = self.session_manager.get_session(session_id)
|
state = self.session_manager.get_session(session_id)
|
||||||
if state:
|
if state:
|
||||||
@@ -489,4 +499,37 @@ class HermesACPAgent(acp.Agent):
|
|||||||
)
|
)
|
||||||
self.session_manager.save_session(session_id)
|
self.session_manager.save_session(session_id)
|
||||||
logger.info("Session %s: model switched to %s", session_id, model_id)
|
logger.info("Session %s: model switched to %s", session_id, model_id)
|
||||||
|
return SetSessionModelResponse()
|
||||||
|
logger.warning("Session %s: model switch requested for missing session", session_id)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def set_session_mode(
|
||||||
|
self, mode_id: str, session_id: str, **kwargs: Any
|
||||||
|
) -> SetSessionModeResponse | None:
|
||||||
|
"""Persist the editor-requested mode so ACP clients do not fail on mode switches."""
|
||||||
|
state = self.session_manager.get_session(session_id)
|
||||||
|
if state is None:
|
||||||
|
logger.warning("Session %s: mode switch requested for missing session", session_id)
|
||||||
|
return None
|
||||||
|
setattr(state, "mode", mode_id)
|
||||||
|
self.session_manager.save_session(session_id)
|
||||||
|
logger.info("Session %s: mode switched to %s", session_id, mode_id)
|
||||||
|
return SetSessionModeResponse()
|
||||||
|
|
||||||
|
async def set_config_option(
|
||||||
|
self, config_id: str, session_id: str, value: str, **kwargs: Any
|
||||||
|
) -> SetSessionConfigOptionResponse | None:
|
||||||
|
"""Accept ACP config option updates even when Hermes has no typed ACP config surface yet."""
|
||||||
|
state = self.session_manager.get_session(session_id)
|
||||||
|
if state is None:
|
||||||
|
logger.warning("Session %s: config update requested for missing session", session_id)
|
||||||
|
return None
|
||||||
|
|
||||||
|
options = getattr(state, "config_options", None)
|
||||||
|
if not isinstance(options, dict):
|
||||||
|
options = {}
|
||||||
|
options[str(config_id)] = value
|
||||||
|
setattr(state, "config_options", options)
|
||||||
|
self.session_manager.save_session(session_id)
|
||||||
|
logger.info("Session %s: config option %s updated", session_id, config_id)
|
||||||
|
return SetSessionConfigOptionResponse(config_options=[])
|
||||||
|
|||||||
20
tests/acp/test_entry.py
Normal file
20
tests/acp/test_entry.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
"""Tests for acp_adapter.entry startup wiring."""
|
||||||
|
|
||||||
|
import acp
|
||||||
|
|
||||||
|
from acp_adapter import entry
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_enables_unstable_protocol(monkeypatch):
|
||||||
|
calls = {}
|
||||||
|
|
||||||
|
async def fake_run_agent(agent, **kwargs):
|
||||||
|
calls["kwargs"] = kwargs
|
||||||
|
|
||||||
|
monkeypatch.setattr(entry, "_setup_logging", lambda: None)
|
||||||
|
monkeypatch.setattr(entry, "_load_env", lambda: None)
|
||||||
|
monkeypatch.setattr(acp, "run_agent", fake_run_agent)
|
||||||
|
|
||||||
|
entry.main()
|
||||||
|
|
||||||
|
assert calls["kwargs"]["use_unstable_protocol"] is True
|
||||||
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, AsyncMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import acp
|
import acp
|
||||||
|
from acp.agent.router import build_agent_router
|
||||||
from acp.schema import (
|
from acp.schema import (
|
||||||
AgentCapabilities,
|
AgentCapabilities,
|
||||||
AuthenticateResponse,
|
AuthenticateResponse,
|
||||||
@@ -18,6 +19,8 @@ from acp.schema import (
|
|||||||
NewSessionResponse,
|
NewSessionResponse,
|
||||||
PromptResponse,
|
PromptResponse,
|
||||||
ResumeSessionResponse,
|
ResumeSessionResponse,
|
||||||
|
SetSessionConfigOptionResponse,
|
||||||
|
SetSessionModeResponse,
|
||||||
SessionInfo,
|
SessionInfo,
|
||||||
TextContentBlock,
|
TextContentBlock,
|
||||||
Usage,
|
Usage,
|
||||||
@@ -168,6 +171,74 @@ class TestListAndFork:
|
|||||||
assert fork_resp.session_id != new_resp.session_id
|
assert fork_resp.session_id != new_resp.session_id
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# session configuration / model routing
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSessionConfiguration:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_session_mode_returns_response(self, agent):
|
||||||
|
new_resp = await agent.new_session(cwd="/tmp")
|
||||||
|
resp = await agent.set_session_mode(mode_id="chat", session_id=new_resp.session_id)
|
||||||
|
state = agent.session_manager.get_session(new_resp.session_id)
|
||||||
|
|
||||||
|
assert isinstance(resp, SetSessionModeResponse)
|
||||||
|
assert getattr(state, "mode", None) == "chat"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_set_config_option_returns_response(self, agent):
|
||||||
|
new_resp = await agent.new_session(cwd="/tmp")
|
||||||
|
resp = await agent.set_config_option(
|
||||||
|
config_id="approval_mode",
|
||||||
|
session_id=new_resp.session_id,
|
||||||
|
value="auto",
|
||||||
|
)
|
||||||
|
state = agent.session_manager.get_session(new_resp.session_id)
|
||||||
|
|
||||||
|
assert isinstance(resp, SetSessionConfigOptionResponse)
|
||||||
|
assert getattr(state, "config_options", {}) == {"approval_mode": "auto"}
|
||||||
|
assert resp.config_options == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_router_accepts_stable_session_config_methods(self, agent):
|
||||||
|
new_resp = await agent.new_session(cwd="/tmp")
|
||||||
|
router = build_agent_router(agent)
|
||||||
|
|
||||||
|
mode_result = await router(
|
||||||
|
"session/set_mode",
|
||||||
|
{"modeId": "chat", "sessionId": new_resp.session_id},
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
config_result = await router(
|
||||||
|
"session/set_config_option",
|
||||||
|
{
|
||||||
|
"configId": "approval_mode",
|
||||||
|
"sessionId": new_resp.session_id,
|
||||||
|
"value": "auto",
|
||||||
|
},
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mode_result == {}
|
||||||
|
assert config_result == {"configOptions": []}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_router_accepts_unstable_model_switch_when_enabled(self, agent):
|
||||||
|
new_resp = await agent.new_session(cwd="/tmp")
|
||||||
|
router = build_agent_router(agent, use_unstable_protocol=True)
|
||||||
|
|
||||||
|
result = await router(
|
||||||
|
"session/set_model",
|
||||||
|
{"modelId": "gpt-5.4", "sessionId": new_resp.session_id},
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
state = agent.session_manager.get_session(new_resp.session_id)
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
assert state.model == "gpt-5.4"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# prompt
|
# prompt
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ class TestGetProviderFallbackPriority:
|
|||||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||||
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
monkeypatch.setenv("VOICE_TOOLS_OPENAI_KEY", "sk-test")
|
||||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||||
|
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||||
from tools.transcription_tools import _get_provider
|
from tools.transcription_tools import _get_provider
|
||||||
assert _get_provider({}) == "groq"
|
assert _get_provider({}) == "groq"
|
||||||
@@ -130,9 +131,10 @@ class TestExplicitProviderRespected:
|
|||||||
def test_explicit_local_no_fallback_to_openai(self, monkeypatch):
|
def test_explicit_local_no_fallback_to_openai(self, monkeypatch):
|
||||||
"""GH-1774: provider=local must not silently fall back to openai
|
"""GH-1774: provider=local must not silently fall back to openai
|
||||||
even when an OpenAI API key is set."""
|
even when an OpenAI API key is set."""
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key-here")
|
monkeypatch.setenv("OPENAI_API_KEY", "***")
|
||||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||||
|
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||||
from tools.transcription_tools import _get_provider
|
from tools.transcription_tools import _get_provider
|
||||||
result = _get_provider({"provider": "local"})
|
result = _get_provider({"provider": "local"})
|
||||||
@@ -141,6 +143,7 @@ class TestExplicitProviderRespected:
|
|||||||
def test_explicit_local_no_fallback_to_groq(self, monkeypatch):
|
def test_explicit_local_no_fallback_to_groq(self, monkeypatch):
|
||||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||||
|
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||||
from tools.transcription_tools import _get_provider
|
from tools.transcription_tools import _get_provider
|
||||||
result = _get_provider({"provider": "local"})
|
result = _get_provider({"provider": "local"})
|
||||||
@@ -181,6 +184,7 @@ class TestExplicitProviderRespected:
|
|||||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
||||||
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
monkeypatch.delenv("GROQ_API_KEY", raising=False)
|
||||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||||
|
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||||
from tools.transcription_tools import _get_provider
|
from tools.transcription_tools import _get_provider
|
||||||
# Empty dict = no explicit provider, uses DEFAULT_PROVIDER auto-detect
|
# Empty dict = no explicit provider, uses DEFAULT_PROVIDER auto-detect
|
||||||
@@ -191,6 +195,7 @@ class TestExplicitProviderRespected:
|
|||||||
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
monkeypatch.setenv("GROQ_API_KEY", "gsk-test")
|
||||||
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
monkeypatch.setenv("OPENAI_API_KEY", "sk-real-key")
|
||||||
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
with patch("tools.transcription_tools._HAS_FASTER_WHISPER", False), \
|
||||||
|
patch("tools.transcription_tools._has_local_command", return_value=False), \
|
||||||
patch("tools.transcription_tools._HAS_OPENAI", True):
|
patch("tools.transcription_tools._HAS_OPENAI", True):
|
||||||
from tools.transcription_tools import _get_provider
|
from tools.transcription_tools import _get_provider
|
||||||
result = _get_provider({})
|
result = _get_provider({})
|
||||||
|
|||||||
Reference in New Issue
Block a user