fix: honor session-scoped gateway model overrides

This commit is contained in:
Hygaard
2026-04-11 04:24:45 -05:00
committed by Teknium
parent 671d5068e7
commit a2f9f04c06
3 changed files with 248 additions and 37 deletions

View File

@@ -667,6 +667,7 @@ class GatewayRunner:
def _flush_memories_for_session(
self,
old_session_id: str,
session_key: Optional[str] = None,
):
"""Prompt the agent to save memories/skills before context is lost.
@@ -685,15 +686,12 @@ class GatewayRunner:
return
from run_agent import AIAgent
runtime_kwargs = _resolve_runtime_agent_kwargs()
model, runtime_kwargs = self._resolve_session_agent_runtime(
session_key=session_key,
)
if not runtime_kwargs.get("api_key"):
return
# Resolve model from config — AIAgent's default is OpenRouter-
# formatted ("anthropic/claude-opus-4.6") which fails when the
# active provider is openai-codex.
model = _resolve_gateway_model()
tmp_agent = AIAgent(
**runtime_kwargs,
model=model,
@@ -773,6 +771,7 @@ class GatewayRunner:
async def _async_flush_memories(
self,
old_session_id: str,
session_key: Optional[str] = None,
):
"""Run the sync memory flush in a thread pool so it won't block the event loop."""
loop = asyncio.get_event_loop()
@@ -780,6 +779,7 @@ class GatewayRunner:
None,
self._flush_memories_for_session,
old_session_id,
session_key,
)
@property
@@ -814,6 +814,46 @@ class GatewayRunner:
thread_sessions_per_user=getattr(config, "thread_sessions_per_user", False),
)
def _resolve_session_agent_runtime(
    self,
    *,
    source: Optional[SessionSource] = None,
    session_key: Optional[str] = None,
    user_config: Optional[dict] = None,
) -> tuple[str, dict]:
    """Pick the model and runtime kwargs for a session, honoring /model overrides.

    The session key is taken from ``session_key`` when given; otherwise it is
    derived from ``source`` via ``_session_key_for_source``. A failure while
    deriving the key is treated as "no key", which means no override applies.

    Resolution order:

    1. No override stored for the key: configured model plus freshly resolved
       global runtime kwargs.
    2. Override with a truthy ``api_key``: the override's own
       provider/api_key/base_url/api_mode bundle is returned as-is and the
       global runtime is never resolved. Only ``api_key`` is checked; the
       other three fields are copied through even when they are ``None``.
    3. Override without an ``api_key``: global runtime kwargs are resolved
       and then passed through ``_apply_session_model_override``.

    Args:
        source: Session source used to derive the key when ``session_key``
            is not supplied.
        session_key: Explicit session key; takes precedence over ``source``.
        user_config: Gateway config forwarded to ``_resolve_gateway_model``.

    Returns:
        A ``(model, runtime_kwargs)`` pair.
    """
    key = session_key
    if not key and source is not None:
        try:
            key = self._session_key_for_source(source)
        except Exception:
            # Best effort: without a key we simply fall back to globals.
            key = None

    model = _resolve_gateway_model(user_config)

    override = None
    if key:
        override = self._session_model_overrides.get(key)

    if not override:
        return model, _resolve_runtime_agent_kwargs()

    if override.get("api_key"):
        # Self-contained override: skip global runtime resolution entirely.
        bundle = {
            field: override.get(field)
            for field in ("provider", "api_key", "base_url", "api_mode")
        }
        return override.get("model", model), bundle

    # Partial override: layer it on top of the global runtime.
    runtime_kwargs = _resolve_runtime_agent_kwargs()
    model, runtime_kwargs = self._apply_session_model_override(
        key, model, runtime_kwargs
    )
    return model, runtime_kwargs
def _resolve_turn_agent_config(self, user_message: str, model: str, runtime_kwargs: dict) -> dict:
from agent.smart_model_routing import resolve_turn_route
from hermes_cli.models import resolve_fast_mode_overrides
@@ -1598,7 +1638,7 @@ class GatewayRunner:
for key, entry in _expired_entries:
try:
await self._async_flush_memories(entry.session_id)
await self._async_flush_memories(entry.session_id, key)
# Shut down memory provider and close tool resources
# on the cached agent. Idle agents live in
# _agent_cache (not _running_agents), so look there.
@@ -2867,6 +2907,7 @@ class GatewayRunner:
_hyg_provider = None
_hyg_base_url = None
_hyg_api_key = None
_hyg_data = {}
try:
_hyg_cfg_path = _hermes_home / "config.yaml"
if _hyg_cfg_path.exists():
@@ -2901,15 +2942,17 @@ class GatewayRunner:
_comp_cfg.get("enabled", True)
).lower() in ("true", "1", "yes")
# Resolve provider/base_url from runtime if not in config
if not _hyg_provider or not _hyg_base_url:
try:
_hyg_runtime = _resolve_runtime_agent_kwargs()
_hyg_provider = _hyg_provider or _hyg_runtime.get("provider")
_hyg_base_url = _hyg_base_url or _hyg_runtime.get("base_url")
_hyg_api_key = _hyg_runtime.get("api_key")
except Exception:
pass
try:
_hyg_model, _hyg_runtime = self._resolve_session_agent_runtime(
source=source,
session_key=session_key,
user_config=_hyg_data if isinstance(_hyg_data, dict) else None,
)
_hyg_provider = _hyg_runtime.get("provider") or _hyg_provider
_hyg_base_url = _hyg_runtime.get("base_url") or _hyg_base_url
_hyg_api_key = _hyg_runtime.get("api_key") or _hyg_api_key
except Exception:
pass
# Check custom_providers per-model context_length
# (same fallback as run_agent.py lines 1171-1189).
@@ -2996,7 +3039,11 @@ class GatewayRunner:
try:
from run_agent import AIAgent
_hyg_runtime = _resolve_runtime_agent_kwargs()
_hyg_model, _hyg_runtime = self._resolve_session_agent_runtime(
source=source,
session_key=session_key,
user_config=_hyg_data if isinstance(_hyg_data, dict) else None,
)
if _hyg_runtime.get("api_key"):
_hyg_msgs = [
{"role": m.get("role"), "content": m.get("content")}
@@ -3652,7 +3699,7 @@ class GatewayRunner:
old_entry = self.session_store._entries.get(session_key)
if old_entry:
_flush_task = asyncio.create_task(
self._async_flush_memories(old_entry.session_id)
self._async_flush_memories(old_entry.session_id, session_key)
)
self._background_tasks.add(_flush_task)
_flush_task.add_done_callback(self._background_tasks.discard)
@@ -4973,7 +5020,11 @@ class GatewayRunner:
_thread_metadata = {"thread_id": source.thread_id} if source.thread_id else None
try:
runtime_kwargs = _resolve_runtime_agent_kwargs()
user_config = _load_gateway_config()
model, runtime_kwargs = self._resolve_session_agent_runtime(
source=source,
user_config=user_config,
)
if not runtime_kwargs.get("api_key"):
await adapter.send(
source.chat_id,
@@ -4982,8 +5033,6 @@ class GatewayRunner:
)
return
user_config = _load_gateway_config()
model = _resolve_gateway_model(user_config)
platform_key = _platform_config_key(source.platform)
from hermes_cli.tools_config import _get_platform_tools
@@ -5143,7 +5192,12 @@ class GatewayRunner:
_thread_meta = {"thread_id": source.thread_id} if source.thread_id else None
try:
runtime_kwargs = _resolve_runtime_agent_kwargs()
user_config = _load_gateway_config()
model, runtime_kwargs = self._resolve_session_agent_runtime(
source=source,
session_key=session_key,
user_config=user_config,
)
if not runtime_kwargs.get("api_key"):
await adapter.send(
source.chat_id,
@@ -5152,8 +5206,6 @@ class GatewayRunner:
)
return
user_config = _load_gateway_config()
model = _resolve_gateway_model(user_config)
platform_key = _platform_config_key(source.platform)
reasoning_config = self._load_reasoning_config()
self._service_tier = self._load_service_tier()
@@ -5490,13 +5542,14 @@ class GatewayRunner:
from agent.manual_compression_feedback import summarize_manual_compression
from agent.model_metadata import estimate_messages_tokens_rough
runtime_kwargs = _resolve_runtime_agent_kwargs()
session_key = self._session_key_for_source(source)
model, runtime_kwargs = self._resolve_session_agent_runtime(
source=source,
session_key=session_key,
)
if not runtime_kwargs.get("api_key"):
return "No provider configured -- cannot compress."
# Resolve model from config (same reason as memory flush above).
model = _resolve_gateway_model()
msgs = [
{"role": m.get("role"), "content": m.get("content")}
for m in history
@@ -5656,7 +5709,7 @@ class GatewayRunner:
# Flush memories for current session before switching
try:
_flush_task = asyncio.create_task(
self._async_flush_memories(current_entry.session_id)
self._async_flush_memories(current_entry.session_id, session_key)
)
self._background_tasks.add(_flush_task)
_flush_task.add_done_callback(self._background_tasks.discard)
@@ -7227,10 +7280,12 @@ class GatewayRunner:
except Exception:
pass
model = _resolve_gateway_model(user_config)
try:
runtime_kwargs = _resolve_runtime_agent_kwargs()
model, runtime_kwargs = self._resolve_session_agent_runtime(
source=source,
session_key=session_key,
user_config=user_config,
)
except Exception as exc:
return {
"final_response": f"⚠️ Provider authentication failed: {exc}",
@@ -7239,11 +7294,6 @@ class GatewayRunner:
"tools": [],
}
# /model overrides take precedence over config.yaml defaults.
model, runtime_kwargs = self._apply_session_model_override(
session_key, model, runtime_kwargs
)
pr = self._provider_routing
reasoning_config = self._load_reasoning_config()
self._reasoning_config = reasoning_config

View File

@@ -221,5 +221,6 @@ class TestHandleResumeCommand:
runner._async_flush_memories.assert_called_once_with(
"current_session_001",
"agent:main:telegram:dm:67890",
)
db.close()

View File

@@ -0,0 +1,160 @@
"""Regression tests for session-scoped model/provider overrides in gateway agents.
These cover the bug where `/model ...` stored a session override, but fresh
agent constructions still resolved model/provider from global config/runtime.
That let helper agents (and cache-miss main agents) route GPT-5.4 to the wrong
provider, e.g. Nous instead of OpenAI Codex.
"""
import asyncio
import sys
import threading
import types
from unittest.mock import AsyncMock, MagicMock
import pytest
import gateway.run as gateway_run
from gateway.config import Platform
from gateway.session import SessionSource
class _CapturingAgent:
"""Fake agent that records init kwargs for assertions."""
last_init = None
def __init__(self, *args, **kwargs):
type(self).last_init = dict(kwargs)
self.tools = []
def run_conversation(self, user_message: str, conversation_history=None, task_id=None):
return {
"final_response": "ok",
"messages": [],
"api_calls": 1,
}
def _make_runner():
    """Build a ``GatewayRunner`` without running its ``__init__``.

    The instance is allocated with ``object.__new__`` and then given the
    attributes set below: empty or ``None`` defaults, a real lock for the
    agent cache, a honcho factory that yields ``(None, None)``, and mocked
    hooks with an awaitable ``emit``.
    """
    runner = object.__new__(gateway_run.GatewayRunner)

    # Plain-data defaults; every container here is a fresh object.
    defaults = {
        "adapters": {},
        "session_store": None,
        "config": None,
        "_voice_mode": {},
        "_ephemeral_system_prompt": "",
        "_prefill_messages": [],
        "_reasoning_config": None,
        "_show_reasoning": False,
        "_provider_routing": {},
        "_fallback_model": None,
        "_service_tier": None,
        "_running_agents": {},
        "_running_agents_ts": {},
        "_background_tasks": set(),
        "_session_db": None,
        "_session_model_overrides": {},
        "_pending_model_notes": {},
        "_pending_approvals": {},
        "_agent_cache": {},
    }
    for attr, value in defaults.items():
        setattr(runner, attr, value)

    runner._agent_cache_lock = threading.Lock()
    runner._get_or_create_gateway_honcho = lambda session_key: (None, None)

    hooks = MagicMock()
    hooks.emit = AsyncMock()
    hooks.loaded_hooks = []
    runner.hooks = hooks

    return runner
def _codex_override():
return {
"model": "gpt-5.4",
"provider": "openai-codex",
"api_key": "***",
"base_url": "https://chatgpt.com/backend-api/codex",
"api_mode": "codex_responses",
}
def _explode_runtime_resolution():
raise AssertionError(
"global runtime resolution should not run when a complete session override exists"
)
def test_run_agent_prefers_session_override_over_global_runtime(monkeypatch):
    """``_run_agent`` must build its agent from the session override.

    Global runtime resolution is patched to raise, so the test only passes
    when the stored override alone supplies model, provider and credentials.
    """
    # Swap in a fake ``run_agent`` module whose AIAgent records its kwargs.
    stub_module = types.ModuleType("run_agent")
    stub_module.AIAgent = _CapturingAgent

    monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
    monkeypatch.setattr(gateway_run, "load_dotenv", lambda *args, **kwargs: None)
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution)
    monkeypatch.setitem(sys.modules, "run_agent", stub_module)
    _CapturingAgent.last_init = None

    runner = _make_runner()
    session_key = "agent:main:local:dm"
    runner._session_model_overrides[session_key] = _codex_override()
    source = SessionSource(
        platform=Platform.LOCAL,
        chat_id="cli",
        chat_name="CLI",
        chat_type="dm",
        user_id="user-1",
    )

    coro = runner._run_agent(
        message="ping",
        context_prompt="",
        history=[],
        source=source,
        session_id="session-1",
        session_key=session_key,
    )
    result = asyncio.run(coro)

    assert result["final_response"] == "ok"

    captured = _CapturingAgent.last_init
    assert captured is not None
    expected = {
        "model": "gpt-5.4",
        "provider": "openai-codex",
        "api_mode": "codex_responses",
        "base_url": "https://chatgpt.com/backend-api/codex",
        "api_key": "***",
    }
    # Compare only the override-controlled kwargs; the agent gets others too.
    assert {name: captured[name] for name in expected} == expected
@pytest.mark.asyncio
async def test_background_task_prefers_session_override_over_global_runtime(monkeypatch):
    """``_run_background_task`` must build its agent from the session override.

    The override is stored under the key the runner derives from the source.
    Global runtime resolution is patched to raise, so reaching it fails the
    test.
    """
    # Swap in a fake ``run_agent`` module whose AIAgent records its kwargs.
    stub_module = types.ModuleType("run_agent")
    stub_module.AIAgent = _CapturingAgent

    monkeypatch.setattr(gateway_run, "_load_gateway_config", lambda: {})
    monkeypatch.setattr(gateway_run, "_resolve_runtime_agent_kwargs", _explode_runtime_resolution)
    monkeypatch.setitem(sys.modules, "run_agent", stub_module)
    _CapturingAgent.last_init = None

    runner = _make_runner()

    telegram = AsyncMock()
    telegram.send = AsyncMock()
    telegram.extract_media = MagicMock(return_value=([], "ok"))
    telegram.extract_images = MagicMock(return_value=([], "ok"))
    runner.adapters[Platform.TELEGRAM] = telegram

    source = SessionSource(
        platform=Platform.TELEGRAM,
        user_id="12345",
        chat_id="67890",
        user_name="testuser",
    )
    derived_key = runner._session_key_for_source(source)
    runner._session_model_overrides[derived_key] = _codex_override()

    await runner._run_background_task("say hello", source, "bg_test")

    captured = _CapturingAgent.last_init
    assert captured is not None
    expected = {
        "model": "gpt-5.4",
        "provider": "openai-codex",
        "api_mode": "codex_responses",
        "base_url": "https://chatgpt.com/backend-api/codex",
        "api_key": "***",
    }
    # Compare only the override-controlled kwargs; the agent gets others too.
    assert {name: captured[name] for name in expected} == expected