fix: honor session-scoped gateway model overrides
This commit is contained in:
124
gateway/run.py
124
gateway/run.py
@@ -667,6 +667,7 @@ class GatewayRunner:
|
||||
def _flush_memories_for_session(
|
||||
self,
|
||||
old_session_id: str,
|
||||
session_key: Optional[str] = None,
|
||||
):
|
||||
"""Prompt the agent to save memories/skills before context is lost.
|
||||
|
||||
@@ -685,15 +686,12 @@ class GatewayRunner:
|
||||
return
|
||||
|
||||
from run_agent import AIAgent
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
model, runtime_kwargs = self._resolve_session_agent_runtime(
|
||||
session_key=session_key,
|
||||
)
|
||||
if not runtime_kwargs.get("api_key"):
|
||||
return
|
||||
|
||||
# Resolve model from config — AIAgent's default is OpenRouter-
|
||||
# formatted ("anthropic/claude-opus-4.6") which fails when the
|
||||
# active provider is openai-codex.
|
||||
model = _resolve_gateway_model()
|
||||
|
||||
tmp_agent = AIAgent(
|
||||
**runtime_kwargs,
|
||||
model=model,
|
||||
@@ -773,6 +771,7 @@ class GatewayRunner:
|
||||
async def _async_flush_memories(
|
||||
self,
|
||||
old_session_id: str,
|
||||
session_key: Optional[str] = None,
|
||||
):
|
||||
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -780,6 +779,7 @@ class GatewayRunner:
|
||||
None,
|
||||
self._flush_memories_for_session,
|
||||
old_session_id,
|
||||
session_key,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -814,6 +814,46 @@ class GatewayRunner:
|
||||
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
|
||||
)
|
||||
|
||||
def _resolve_session_agent_runtime(
|
||||
self,
|
||||
*,
|
||||
source: Optional[SessionSource] = None,
|
||||
session_key: Optional[str] = None,
|
||||
user_config: Optional[dict] = None,
|
||||
) -> tuple[str, dict]:
|
||||
"""Resolve model/runtime for a session, honoring session-scoped /model overrides.
|
||||
|
||||
If the session override already contains a complete provider bundle
|
||||
(provider/api_key/base_url/api_mode), prefer it directly instead of
|
||||
resolving fresh global runtime state first.
|
||||
"""
|
||||
resolved_session_key = session_key
|
||||
if not resolved_session_key and source is not None:
|
||||
try:
|
||||
resolved_session_key = self._session_key_for_source(source)
|
||||
except Exception:
|
||||
resolved_session_key = None
|
||||
|
||||
model = _resolve_gateway_model(user_config)
|
||||
override = self._session_model_overrides.get(resolved_session_key) if resolved_session_key else None
|
||||
if override:
|
||||
override_model = override.get("model", model)
|
||||
override_runtime = {
|
||||
"provider": override.get("provider"),
|
||||
"api_key": override.get("api_key"),
|
||||
"base_url": override.get("base_url"),
|
||||
"api_mode": override.get("api_mode"),
|
||||
}
|
||||
if override_runtime.get("api_key"):
|
||||
return override_model, override_runtime
|
||||
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
if override and resolved_session_key:
|
||||
model, runtime_kwargs = self._apply_session_model_override(
|
||||
resolved_session_key, model, runtime_kwargs
|
||||
)
|
||||
return model, runtime_kwargs
|
||||
|
||||
def _resolve_turn_agent_config(self, user_message: str, model: str, runtime_kwargs: dict) -> dict:
|
||||
from agent.smart_model_routing import resolve_turn_route
|
||||
from hermes_cli.models import resolve_fast_mode_overrides
|
||||
@@ -1598,7 +1638,7 @@ class GatewayRunner:
|
||||
|
||||
for key, entry in _expired_entries:
|
||||
try:
|
||||
await self._async_flush_memories(entry.session_id)
|
||||
await self._async_flush_memories(entry.session_id, key)
|
||||
# Shut down memory provider and close tool resources
|
||||
# on the cached agent. Idle agents live in
|
||||
# _agent_cache (not _running_agents), so look there.
|
||||
@@ -2867,6 +2907,7 @@ class GatewayRunner:
|
||||
_hyg_provider = None
|
||||
_hyg_base_url = None
|
||||
_hyg_api_key = None
|
||||
_hyg_data = {}
|
||||
try:
|
||||
_hyg_cfg_path = _hermes_home / "config.yaml"
|
||||
if _hyg_cfg_path.exists():
|
||||
@@ -2901,15 +2942,17 @@ class GatewayRunner:
|
||||
_comp_cfg.get("enabled", True)
|
||||
).lower() in ("true", "1", "yes")
|
||||
|
||||
# Resolve provider/base_url from runtime if not in config
|
||||
if not _hyg_provider or not _hyg_base_url:
|
||||
try:
|
||||
_hyg_runtime = _resolve_runtime_agent_kwargs()
|
||||
_hyg_provider = _hyg_provider or _hyg_runtime.get("provider")
|
||||
_hyg_base_url = _hyg_base_url or _hyg_runtime.get("base_url")
|
||||
_hyg_api_key = _hyg_runtime.get("api_key")
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
_hyg_model, _hyg_runtime = self._resolve_session_agent_runtime(
|
||||
source=source,
|
||||
session_key=session_key,
|
||||
user_config=_hyg_data if isinstance(_hyg_data, dict) else None,
|
||||
)
|
||||
_hyg_provider = _hyg_runtime.get("provider") or _hyg_provider
|
||||
_hyg_base_url = _hyg_runtime.get("base_url") or _hyg_base_url
|
||||
_hyg_api_key = _hyg_runtime.get("api_key") or _hyg_api_key
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check custom_providers per-model context_length
|
||||
# (same fallback as run_agent.py lines 1171-1189).
|
||||
@@ -2996,7 +3039,11 @@ class GatewayRunner:
|
||||
try:
|
||||
from run_agent import AIAgent
|
||||
|
||||
_hyg_runtime = _resolve_runtime_agent_kwargs()
|
||||
_hyg_model, _hyg_runtime = self._resolve_session_agent_runtime(
|
||||
source=source,
|
||||
session_key=session_key,
|
||||
user_config=_hyg_data if isinstance(_hyg_data, dict) else None,
|
||||
)
|
||||
if _hyg_runtime.get("api_key"):
|
||||
_hyg_msgs = [
|
||||
{"role": m.get("role"), "content": m.get("content")}
|
||||
@@ -3652,7 +3699,7 @@ class GatewayRunner:
|
||||
old_entry = self.session_store._entries.get(session_key)
|
||||
if old_entry:
|
||||
_flush_task = asyncio.create_task(
|
||||
self._async_flush_memories(old_entry.session_id)
|
||||
self._async_flush_memories(old_entry.session_id, session_key)
|
||||
)
|
||||
self._background_tasks.add(_flush_task)
|
||||
_flush_task.add_done_callback(self._background_tasks.discard)
|
||||
@@ -4973,7 +5020,11 @@ class GatewayRunner:
|
||||
_thread_metadata = {"thread_id": source.thread_id} if source.thread_id else None
|
||||
|
||||
try:
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
user_config = _load_gateway_config()
|
||||
model, runtime_kwargs = self._resolve_session_agent_runtime(
|
||||
source=source,
|
||||
user_config=user_config,
|
||||
)
|
||||
if not runtime_kwargs.get("api_key"):
|
||||
await adapter.send(
|
||||
source.chat_id,
|
||||
@@ -4982,8 +5033,6 @@ class GatewayRunner:
|
||||
)
|
||||
return
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
model = _resolve_gateway_model(user_config)
|
||||
platform_key = _platform_config_key(source.platform)
|
||||
|
||||
from hermes_cli.tools_config import _get_platform_tools
|
||||
@@ -5143,7 +5192,12 @@ class GatewayRunner:
|
||||
_thread_meta = {"thread_id": source.thread_id} if source.thread_id else None
|
||||
|
||||
try:
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
user_config = _load_gateway_config()
|
||||
model, runtime_kwargs = self._resolve_session_agent_runtime(
|
||||
source=source,
|
||||
session_key=session_key,
|
||||
user_config=user_config,
|
||||
)
|
||||
if not runtime_kwargs.get("api_key"):
|
||||
await adapter.send(
|
||||
source.chat_id,
|
||||
@@ -5152,8 +5206,6 @@ class GatewayRunner:
|
||||
)
|
||||
return
|
||||
|
||||
user_config = _load_gateway_config()
|
||||
model = _resolve_gateway_model(user_config)
|
||||
platform_key = _platform_config_key(source.platform)
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._service_tier = self._load_service_tier()
|
||||
@@ -5490,13 +5542,14 @@ class GatewayRunner:
|
||||
from agent.manual_compression_feedback import summarize_manual_compression
|
||||
from agent.model_metadata import estimate_messages_tokens_rough
|
||||
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
session_key = self._session_key_for_source(source)
|
||||
model, runtime_kwargs = self._resolve_session_agent_runtime(
|
||||
source=source,
|
||||
session_key=session_key,
|
||||
)
|
||||
if not runtime_kwargs.get("api_key"):
|
||||
return "No provider configured -- cannot compress."
|
||||
|
||||
# Resolve model from config (same reason as memory flush above).
|
||||
model = _resolve_gateway_model()
|
||||
|
||||
msgs = [
|
||||
{"role": m.get("role"), "content": m.get("content")}
|
||||
for m in history
|
||||
@@ -5656,7 +5709,7 @@ class GatewayRunner:
|
||||
# Flush memories for current session before switching
|
||||
try:
|
||||
_flush_task = asyncio.create_task(
|
||||
self._async_flush_memories(current_entry.session_id)
|
||||
self._async_flush_memories(current_entry.session_id, session_key)
|
||||
)
|
||||
self._background_tasks.add(_flush_task)
|
||||
_flush_task.add_done_callback(self._background_tasks.discard)
|
||||
@@ -7227,10 +7280,12 @@ class GatewayRunner:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model = _resolve_gateway_model(user_config)
|
||||
|
||||
try:
|
||||
runtime_kwargs = _resolve_runtime_agent_kwargs()
|
||||
model, runtime_kwargs = self._resolve_session_agent_runtime(
|
||||
source=source,
|
||||
session_key=session_key,
|
||||
user_config=user_config,
|
||||
)
|
||||
except Exception as exc:
|
||||
return {
|
||||
"final_response": f"⚠️ Provider authentication failed: {exc}",
|
||||
@@ -7239,11 +7294,6 @@ class GatewayRunner:
|
||||
"tools": [],
|
||||
}
|
||||
|
||||
# /model overrides take precedence over config.yaml defaults.
|
||||
model, runtime_kwargs = self._apply_session_model_override(
|
||||
session_key, model, runtime_kwargs
|
||||
)
|
||||
|
||||
pr = self._provider_routing
|
||||
reasoning_config = self._load_reasoning_config()
|
||||
self._reasoning_config = reasoning_config
|
||||
|
||||
@@ -221,5 +221,6 @@ class TestHandleResumeCommand:
|
||||
|
||||
runner._async_flush_memories.assert_called_once_with(
|
||||
"current_session_001",
|
||||
"agent:main:telegram:dm:67890",
|
||||
)
|
||||
db.close()
|
||||
|
||||
160
tests/gateway/test_session_model_override_routing.py
Normal file
160
tests/gateway/test_session_model_override_routing.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""Regression tests for session-scoped model/provider overrides in gateway agents.
|
||||
|
||||
These cover the bug where `/model ...` stored a session override, but fresh
|
||||
agent constructions still resolved model/provider from global config/runtime.
|
||||
That let helper agents (and cache-miss main agents) route GPT-5.4 to the wrong
|
||||
provider, e.g. Nous instead of OpenAI Codex.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import gateway.run as gateway_run
|
||||
from gateway.config import Platform
|
||||
from gateway.session import SessionSource
|
||||
|
||||
|
||||
class _CapturingAgent:
|
||||
"""Fake agent that records init kwargs for assertions."""
|
||||
|
||||
last_init = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self).last_init = dict(kwargs)
|
||||
self.tools = []
|
||||
|
||||
def run_conversation(self, user_message: str, conversation_history=None, task_id=None):
|
||||
return {
|
||||
"final_response": "ok",
|
||||
"messages": [],
|
||||
"api_calls": 1,
|
||||
}
|
||||
|
||||
|
||||
def _make_runner():
|
||||
runner = object.__new__(gateway_run.GatewayRunner)
|
||||
runner.adapters = {}
|
||||
runner.session_store = None
|
||||
runner.config = None
|
||||
runner._voice_mode = {}
|
||||
runner._ephemeral_system_prompt = ""
|
||||
runner._prefill_messages = []
|
||||
runner._reasoning_config = None
|
||||
runner._show_reasoning = False
|
||||
runner._provider_routing = {}
|
||||
runner._fallback_model = None
|
||||
runner._service_tier = None
|
||||
runner._running_agents = {}
|
||||
runner._running_agents_ts = {}
|
||||
runner._background_tasks = set()
|
||||
runner._session_db = None
|
||||
runner._session_model_overrides = {}
|
||||
runner._pending_model_notes = {}
|
||||
runner._pending_approvals = {}
|
||||
runner._agent_cache = {}
|
||||
runner._agent_cache_lock = threading.Lock()
|
||||
runner._get_or_create_gateway_honcho = lambda session_key: (None, None)
|
||||
runner.hooks = MagicMock()
|
||||
runner.hooks.emit = AsyncMock()
|
||||
runner.hooks.loaded_hooks = []
|
||||
return runner
|
||||
|
||||
|
||||
def _codex_override():
|
||||
return {
|
||||
"model": "gpt-5.4",
|
||||
"provider": "openai-codex",
|
||||
"api_key": "***",
|
||||
"base_url": "https://chatgpt.com/backend-api/codex",
|
||||
"api_mode": "codex_responses",
|
||||
}
|
||||
|
||||
|
||||
def _explode_runtime_resolution():
|
||||
raise AssertionError(
|
||||
"global runtime resolution should not run when a complete session override exists"
|
||||
)
|
||||
|
||||
|
||||
def test_run_agent_prefers_session_override_over_global_runtime(monkeypatch):
|
||||
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
|
||||
monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.LOCAL,
|
||||
chat_id="cli",
|
||||
chat_name="CLI",
|
||||
chat_type="dm",
|
||||
user_id="user-1",
|
||||
)
|
||||
session_key = "agent:main:local:dm"
|
||||
runner._session_model_overrides[session_key] = _codex_override()
|
||||
|
||||
result = asyncio.run(
|
||||
runner._run_agent(
|
||||
message="ping",
|
||||
context_prompt="",
|
||||
history=[],
|
||||
source=source,
|
||||
session_id="session-1",
|
||||
session_key=session_key,
|
||||
)
|
||||
)
|
||||
|
||||
assert result["final_response"] == "ok"
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["model"] == "gpt-5.4"
|
||||
assert _CapturingAgent.last_init["provider"] == "openai-codex"
|
||||
assert _CapturingAgent.last_init["api_mode"] == "codex_responses"
|
||||
assert _CapturingAgent.last_init["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
assert _CapturingAgent.last_init["api_key"] == "***"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_task_prefers_session_override_over_global_runtime(monkeypatch):
|
||||
monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
|
||||
monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution)
|
||||
|
||||
fake_run_agent = types.ModuleType("run_agent")
|
||||
fake_run_agent.AIAgent = _CapturingAgent
|
||||
monkeypatch.setitem(sys.modules, "run_agent", fake_run_agent)
|
||||
|
||||
_CapturingAgent.last_init = None
|
||||
runner = _make_runner()
|
||||
|
||||
adapter = AsyncMock()
|
||||
adapter.send = AsyncMock()
|
||||
adapter.extract_media = MagicMock(return_value=([], "ok"))
|
||||
adapter.extract_images = MagicMock(return_value=([], "ok"))
|
||||
runner.adapters[Platform.TELEGRAM] = adapter
|
||||
|
||||
source = SessionSource(
|
||||
platform=Platform.TELEGRAM,
|
||||
user_id="12345",
|
||||
chat_id="67890",
|
||||
user_name="testuser",
|
||||
)
|
||||
session_key = runner._session_key_for_source(source)
|
||||
runner._session_model_overrides[session_key] = _codex_override()
|
||||
|
||||
await runner._run_background_task("say hello", source, "bg_test")
|
||||
|
||||
assert _CapturingAgent.last_init is not None
|
||||
assert _CapturingAgent.last_init["model"] == "gpt-5.4"
|
||||
assert _CapturingAgent.last_init["provider"] == "openai-codex"
|
||||
assert _CapturingAgent.last_init["api_mode"] == "codex_responses"
|
||||
assert _CapturingAgent.last_init["base_url"] == "https://chatgpt.com/backend-api/codex"
|
||||
assert _CapturingAgent.last_init["api_key"] == "***"
|
||||
Reference in New Issue
Block a user