fix: provider/model resolution — salvage 4 PRs + MiniMax aux URL fix (#5983)
Salvaged fixes from community PRs: - fix(model_switch): _read_auth_store → _load_auth_store + fix auth store key lookup (was checking top-level dict instead of store['providers']). OAuth providers now correctly detected in /model picker. Cherry-picked from PR #5911 by Xule Lin (linxule). - fix(ollama): pass num_ctx to override 2048 default context window. Ollama defaults to 2048 context regardless of model capabilities. Now auto-detects from /api/show metadata and injects num_ctx into every request. Config override via model.ollama_num_ctx. Fixes #2708. Cherry-picked from PR #5929 by kshitij (kshitijk4poor). - fix(aux): normalize provider aliases for vision/auxiliary routing. Adds _normalize_aux_provider() with 17 aliases (google→gemini, claude→anthropic, glm→zai, etc). Fixes vision routing failure when provider is set to 'google' instead of 'gemini'. Cherry-picked from PR #5793 by e11i (Elizabeth1979). - fix(aux): rewrite MiniMax /anthropic base URLs to /v1 for OpenAI SDK. MiniMax's inference_base_url ends in /anthropic (Anthropic Messages API), but auxiliary client uses OpenAI SDK which appends /chat/completions → 404 at /anthropic/chat/completions. Generic _to_openai_base_url() helper rewrites terminal /anthropic to /v1 for OpenAI-compatible endpoint. Inspired by PR #5786 by Lempkey. Added debug logging to silent exception blocks across all fixes. Co-authored-by: Hermes Agent <hermes@nousresearch.com>
This commit is contained in:
@@ -59,6 +59,41 @@ from hermes_constants import OPENROUTER_BASE_URL
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_PROVIDER_ALIASES = {
|
||||
"google": "gemini",
|
||||
"google-gemini": "gemini",
|
||||
"google-ai-studio": "gemini",
|
||||
"glm": "zai",
|
||||
"z-ai": "zai",
|
||||
"z.ai": "zai",
|
||||
"zhipu": "zai",
|
||||
"kimi": "kimi-coding",
|
||||
"moonshot": "kimi-coding",
|
||||
"minimax-china": "minimax-cn",
|
||||
"minimax_cn": "minimax-cn",
|
||||
"claude": "anthropic",
|
||||
"claude-code": "anthropic",
|
||||
}
|
||||
|
||||
|
||||
def _normalize_aux_provider(provider: Optional[str], *, for_vision: bool = False) -> str:
|
||||
normalized = (provider or "auto").strip().lower()
|
||||
if normalized.startswith("custom:"):
|
||||
suffix = normalized.split(":", 1)[1].strip()
|
||||
if not suffix:
|
||||
return "custom"
|
||||
normalized = suffix if not for_vision else "custom"
|
||||
if normalized == "codex":
|
||||
return "openai-codex"
|
||||
if normalized == "main":
|
||||
# Resolve to the user's actual main provider so named custom providers
|
||||
# and non-aggregator providers (DeepSeek, Alibaba, etc.) work correctly.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
return main_prov
|
||||
return "custom"
|
||||
return _PROVIDER_ALIASES.get(normalized, normalized)
|
||||
|
||||
# Default auxiliary models for direct API-key providers (cheap/fast for side tasks)
|
||||
_API_KEY_PROVIDER_AUX_MODELS: Dict[str, str] = {
|
||||
"gemini": "gemini-3-flash-preview",
|
||||
@@ -106,6 +141,23 @@ _CODEX_AUX_MODEL = "gpt-5.2-codex"
|
||||
_CODEX_AUX_BASE_URL = "https://chatgpt.com/backend-api/codex"
|
||||
|
||||
|
||||
def _to_openai_base_url(base_url: str) -> str:
|
||||
"""Normalize an Anthropic-style base URL to OpenAI-compatible format.
|
||||
|
||||
Some providers (MiniMax, MiniMax-CN) expose an ``/anthropic`` endpoint for
|
||||
the Anthropic Messages API and a separate ``/v1`` endpoint for OpenAI chat
|
||||
completions. The auxiliary client uses the OpenAI SDK, so it must hit the
|
||||
``/v1`` surface. Passing the raw ``inference_base_url`` causes requests to
|
||||
land on ``/anthropic/chat/completions`` — a 404.
|
||||
"""
|
||||
url = str(base_url or "").strip().rstrip("/")
|
||||
if url.endswith("/anthropic"):
|
||||
rewritten = url[: -len("/anthropic")] + "/v1"
|
||||
logger.debug("Auxiliary client: rewrote base URL %s → %s", url, rewritten)
|
||||
return rewritten
|
||||
return url
|
||||
|
||||
|
||||
def _select_pool_entry(provider: str) -> Tuple[bool, Optional[Any]]:
|
||||
"""Return (pool_exists_for_provider, selected_entry)."""
|
||||
try:
|
||||
@@ -635,7 +687,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
if not api_key:
|
||||
continue
|
||||
|
||||
base_url = _pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url
|
||||
base_url = _to_openai_base_url(
|
||||
_pool_runtime_base_url(entry, pconfig.inference_base_url) or pconfig.inference_base_url
|
||||
)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s) via pool", pconfig.name, model)
|
||||
extra = {}
|
||||
@@ -652,7 +706,9 @@ def _resolve_api_key_provider() -> Tuple[Optional[OpenAI], Optional[str]]:
|
||||
if not api_key:
|
||||
continue
|
||||
|
||||
base_url = str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
base_url = _to_openai_base_url(
|
||||
str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
)
|
||||
model = _API_KEY_PROVIDER_AUX_MODELS.get(provider_id, "default")
|
||||
logger.debug("Auxiliary text client: %s (%s)", pconfig.name, model)
|
||||
extra = {}
|
||||
@@ -778,7 +834,7 @@ def _read_main_provider() -> str:
|
||||
if isinstance(model_cfg, dict):
|
||||
provider = model_cfg.get("provider", "")
|
||||
if isinstance(provider, str) and provider.strip():
|
||||
return provider.strip().lower()
|
||||
return _normalize_aux_provider(provider)
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
@@ -1140,17 +1196,7 @@ def resolve_provider_client(
|
||||
(client, resolved_model) or (None, None) if auth is unavailable.
|
||||
"""
|
||||
# Normalise aliases
|
||||
provider = (provider or "auto").strip().lower()
|
||||
if provider == "codex":
|
||||
provider = "openai-codex"
|
||||
if provider == "main":
|
||||
# Resolve to the user's actual main provider so named custom providers
|
||||
# and non-aggregator providers (DeepSeek, Alibaba, etc.) work correctly.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
provider = main_prov
|
||||
else:
|
||||
provider = "custom"
|
||||
provider = _normalize_aux_provider(provider)
|
||||
|
||||
# ── Auto: try all providers in priority order ────────────────────
|
||||
if provider == "auto":
|
||||
@@ -1300,7 +1346,9 @@ def resolve_provider_client(
|
||||
provider, ", ".join(tried_sources))
|
||||
return None, None
|
||||
|
||||
base_url = str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
base_url = _to_openai_base_url(
|
||||
str(creds.get("base_url", "")).strip().rstrip("/") or pconfig.inference_base_url
|
||||
)
|
||||
|
||||
default_model = _API_KEY_PROVIDER_AUX_MODELS.get(provider, "")
|
||||
final_model = model or default_model
|
||||
@@ -1384,17 +1432,7 @@ _VISION_AUTO_PROVIDER_ORDER = (
|
||||
|
||||
|
||||
def _normalize_vision_provider(provider: Optional[str]) -> str:
|
||||
provider = (provider or "auto").strip().lower()
|
||||
if provider == "codex":
|
||||
return "openai-codex"
|
||||
if provider == "main":
|
||||
# Resolve to actual main provider — named custom providers and
|
||||
# non-aggregator providers need to pass through as their real name.
|
||||
main_prov = _read_main_provider()
|
||||
if main_prov and main_prov not in ("auto", "main", ""):
|
||||
return main_prov
|
||||
return "custom"
|
||||
return provider
|
||||
return _normalize_aux_provider(provider, for_vision=True)
|
||||
|
||||
|
||||
def _resolve_strict_vision_backend(provider: str) -> Tuple[Optional[Any], Optional[str]]:
|
||||
|
||||
@@ -611,6 +611,59 @@ def _model_id_matches(candidate_id: str, lookup_model: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def query_ollama_num_ctx(model: str, base_url: str) -> Optional[int]:
|
||||
"""Query an Ollama server for the model's context length.
|
||||
|
||||
Returns the model's maximum context from GGUF metadata via ``/api/show``,
|
||||
or the explicit ``num_ctx`` from the Modelfile if set. Returns None if
|
||||
the server is unreachable or not Ollama.
|
||||
|
||||
This is the value that should be passed as ``num_ctx`` in Ollama chat
|
||||
requests to override the default 2048.
|
||||
"""
|
||||
import httpx
|
||||
|
||||
bare_model = _strip_provider_prefix(model)
|
||||
server_url = base_url.rstrip("/")
|
||||
if server_url.endswith("/v1"):
|
||||
server_url = server_url[:-3]
|
||||
|
||||
try:
|
||||
server_type = detect_local_server_type(base_url)
|
||||
except Exception:
|
||||
return None
|
||||
if server_type != "ollama":
|
||||
return None
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=3.0) as client:
|
||||
resp = client.post(f"{server_url}/api/show", json={"name": bare_model})
|
||||
if resp.status_code != 200:
|
||||
return None
|
||||
data = resp.json()
|
||||
|
||||
# Prefer explicit num_ctx from Modelfile parameters (user override)
|
||||
params = data.get("parameters", "")
|
||||
if "num_ctx" in params:
|
||||
for line in params.split("\n"):
|
||||
if "num_ctx" in line:
|
||||
parts = line.strip().split()
|
||||
if len(parts) >= 2:
|
||||
try:
|
||||
return int(parts[-1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Fall back to GGUF model_info context_length (training max)
|
||||
model_info = data.get("model_info", {})
|
||||
for key, value in model_info.items():
|
||||
if "context_length" in key and isinstance(value, (int, float)):
|
||||
return int(value)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def _query_local_context_length(model: str, base_url: str) -> Optional[int]:
|
||||
"""Query a local server for the model's context length."""
|
||||
import httpx
|
||||
|
||||
@@ -791,12 +791,12 @@ def list_authenticated_providers(
|
||||
if overlay.auth_type in ("oauth_device_code", "oauth_external", "external_process"):
|
||||
# These use auth stores, not env vars — check for auth.json entries
|
||||
try:
|
||||
from hermes_cli.auth import _read_auth_store
|
||||
store = _read_auth_store()
|
||||
if store and pid in store:
|
||||
from hermes_cli.auth import _load_auth_store
|
||||
store = _load_auth_store()
|
||||
if store and (pid in store.get("providers", {}) or pid in store.get("credential_pool", {})):
|
||||
has_creds = True
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.debug("Auth store check failed for %s: %s", pid, exc)
|
||||
if not has_creds:
|
||||
continue
|
||||
|
||||
|
||||
37
run_agent.py
37
run_agent.py
@@ -85,6 +85,7 @@ from agent.model_metadata import (
|
||||
estimate_tokens_rough, estimate_messages_tokens_rough, estimate_request_tokens_rough,
|
||||
get_next_probe_tier, parse_context_limit_from_error,
|
||||
save_context_length, is_local_endpoint,
|
||||
query_ollama_num_ctx,
|
||||
)
|
||||
from agent.context_compressor import ContextCompressor
|
||||
from agent.subdirectory_hints import SubdirectoryHintTracker
|
||||
@@ -1216,6 +1217,33 @@ class AIAgent:
|
||||
self.session_cost_status = "unknown"
|
||||
self.session_cost_source = "none"
|
||||
|
||||
# ── Ollama num_ctx injection ──
|
||||
# Ollama defaults to 2048 context regardless of the model's capabilities.
|
||||
# When running against an Ollama server, detect the model's max context
|
||||
# and pass num_ctx on every chat request so the full window is used.
|
||||
# User override: set model.ollama_num_ctx in config.yaml to cap VRAM use.
|
||||
self._ollama_num_ctx: int | None = None
|
||||
_ollama_num_ctx_override = None
|
||||
if isinstance(_model_cfg, dict):
|
||||
_ollama_num_ctx_override = _model_cfg.get("ollama_num_ctx")
|
||||
if _ollama_num_ctx_override is not None:
|
||||
try:
|
||||
self._ollama_num_ctx = int(_ollama_num_ctx_override)
|
||||
except (TypeError, ValueError):
|
||||
logger.debug("Invalid ollama_num_ctx config value: %r", _ollama_num_ctx_override)
|
||||
if self._ollama_num_ctx is None and self.base_url and is_local_endpoint(self.base_url):
|
||||
try:
|
||||
_detected = query_ollama_num_ctx(self.model, self.base_url)
|
||||
if _detected and _detected > 0:
|
||||
self._ollama_num_ctx = _detected
|
||||
except Exception as exc:
|
||||
logger.debug("Ollama num_ctx detection failed: %s", exc)
|
||||
if self._ollama_num_ctx and not self.quiet_mode:
|
||||
logger.info(
|
||||
"Ollama num_ctx: will request %d tokens (model max from /api/show)",
|
||||
self._ollama_num_ctx,
|
||||
)
|
||||
|
||||
if not self.quiet_mode:
|
||||
if compression_enabled:
|
||||
print(f"📊 Context limit: {self.context_compressor.context_length:,} tokens (compress at {int(compression_threshold*100)}% = {self.context_compressor.threshold_tokens:,})")
|
||||
@@ -5456,6 +5484,15 @@ class AIAgent:
|
||||
if _is_nous:
|
||||
extra_body["tags"] = ["product=hermes-agent"]
|
||||
|
||||
# Ollama num_ctx: override the 2048 default so the model actually
|
||||
# uses the context window it was trained for. Passed via the OpenAI
|
||||
# SDK's extra_body → options.num_ctx, which Ollama's OpenAI-compat
|
||||
# endpoint forwards to the runner as --ctx-size.
|
||||
if self._ollama_num_ctx:
|
||||
options = extra_body.get("options", {})
|
||||
options["num_ctx"] = self._ollama_num_ctx
|
||||
extra_body["options"] = options
|
||||
|
||||
if extra_body:
|
||||
api_kwargs["extra_body"] = extra_body
|
||||
|
||||
|
||||
@@ -471,6 +471,23 @@ class TestExplicitProviderRouting:
|
||||
client, model = resolve_provider_client("zai")
|
||||
assert client is not None
|
||||
|
||||
def test_explicit_google_alias_uses_gemini_credentials(self):
|
||||
"""provider='google' should route through the gemini API-key provider."""
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "gemini-key",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
mock_openai.return_value = MagicMock()
|
||||
client, model = resolve_provider_client("google", model="gemini-3.1-pro-preview")
|
||||
|
||||
assert client is not None
|
||||
assert model == "gemini-3.1-pro-preview"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
def test_explicit_unknown_returns_none(self, monkeypatch):
|
||||
"""Unknown provider should return None."""
|
||||
client, model = resolve_provider_client("nonexistent-provider")
|
||||
@@ -822,6 +839,31 @@ class TestAuxiliaryPoolAwareness:
|
||||
assert model == "google/gemini-3-flash-preview"
|
||||
assert client is not None
|
||||
|
||||
def test_vision_config_google_provider_uses_gemini_credentials(self, monkeypatch):
|
||||
config = {
|
||||
"auxiliary": {
|
||||
"vision": {
|
||||
"provider": "google",
|
||||
"model": "gemini-3.1-pro-preview",
|
||||
}
|
||||
}
|
||||
}
|
||||
monkeypatch.setattr("hermes_cli.config.load_config", lambda: config)
|
||||
with (
|
||||
patch("hermes_cli.auth.resolve_api_key_provider_credentials", return_value={
|
||||
"api_key": "gemini-key",
|
||||
"base_url": "https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
}),
|
||||
patch("agent.auxiliary_client.OpenAI") as mock_openai,
|
||||
):
|
||||
resolved_provider, client, model = resolve_vision_provider_client()
|
||||
|
||||
assert resolved_provider == "gemini"
|
||||
assert client is not None
|
||||
assert model == "gemini-3.1-pro-preview"
|
||||
assert mock_openai.call_args.kwargs["api_key"] == "gemini-key"
|
||||
assert mock_openai.call_args.kwargs["base_url"] == "https://generativelanguage.googleapis.com/v1beta/openai"
|
||||
|
||||
def test_vision_forced_main_uses_custom_endpoint(self, monkeypatch):
|
||||
"""When explicitly forced to 'main', vision CAN use custom endpoint."""
|
||||
config = {
|
||||
|
||||
42
tests/agent/test_minimax_auxiliary_url.py
Normal file
42
tests/agent/test_minimax_auxiliary_url.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Tests for MiniMax auxiliary client URL normalization.
|
||||
|
||||
MiniMax and MiniMax-CN set inference_base_url to the /anthropic path.
|
||||
The auxiliary client uses the OpenAI SDK, which needs /v1 instead.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
|
||||
from agent.auxiliary_client import _to_openai_base_url
|
||||
|
||||
|
||||
class TestToOpenaiBaseUrl:
|
||||
def test_minimax_global_anthropic_suffix_replaced(self):
|
||||
assert _to_openai_base_url("https://api.minimax.io/anthropic") == "https://api.minimax.io/v1"
|
||||
|
||||
def test_minimax_cn_anthropic_suffix_replaced(self):
|
||||
assert _to_openai_base_url("https://api.minimaxi.com/anthropic") == "https://api.minimaxi.com/v1"
|
||||
|
||||
def test_trailing_slash_stripped_before_replace(self):
|
||||
assert _to_openai_base_url("https://api.minimax.io/anthropic/") == "https://api.minimax.io/v1"
|
||||
|
||||
def test_v1_url_unchanged(self):
|
||||
assert _to_openai_base_url("https://api.openai.com/v1") == "https://api.openai.com/v1"
|
||||
|
||||
def test_openrouter_url_unchanged(self):
|
||||
assert _to_openai_base_url("https://openrouter.ai/api/v1") == "https://openrouter.ai/api/v1"
|
||||
|
||||
def test_anthropic_domain_unchanged(self):
|
||||
"""api.anthropic.com doesn't end with /anthropic — should be untouched."""
|
||||
assert _to_openai_base_url("https://api.anthropic.com") == "https://api.anthropic.com"
|
||||
|
||||
def test_anthropic_in_subpath_unchanged(self):
|
||||
assert _to_openai_base_url("https://example.com/anthropic/extra") == "https://example.com/anthropic/extra"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert _to_openai_base_url("") == ""
|
||||
|
||||
def test_none(self):
|
||||
assert _to_openai_base_url(None) == ""
|
||||
135
tests/test_ollama_num_ctx.py
Normal file
135
tests/test_ollama_num_ctx.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Tests for Ollama num_ctx context length detection and injection.
|
||||
|
||||
Covers:
|
||||
agent/model_metadata.py — query_ollama_num_ctx()
|
||||
run_agent.py — _ollama_num_ctx detection + extra_body injection
|
||||
"""
|
||||
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from agent.model_metadata import query_ollama_num_ctx
|
||||
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
# Level 1: query_ollama_num_ctx — Ollama API interaction
|
||||
# ═══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _mock_httpx_client(show_response_data, status_code=200):
|
||||
"""Create a mock httpx.Client context manager that returns given /api/show data."""
|
||||
mock_resp = MagicMock(status_code=status_code)
|
||||
mock_resp.json.return_value = show_response_data
|
||||
mock_client = MagicMock()
|
||||
mock_client.post.return_value = mock_resp
|
||||
mock_ctx = MagicMock()
|
||||
mock_ctx.__enter__ = MagicMock(return_value=mock_client)
|
||||
mock_ctx.__exit__ = MagicMock(return_value=False)
|
||||
return mock_ctx, mock_client
|
||||
|
||||
|
||||
class TestQueryOllamaNumCtx:
|
||||
"""Test the Ollama /api/show context length query."""
|
||||
|
||||
def test_returns_context_from_model_info(self):
|
||||
"""Should extract context_length from GGUF model_info metadata."""
|
||||
show_data = {
|
||||
"model_info": {"llama.context_length": 131072},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
# httpx is imported inside the function — patch the module import
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("llama3.1:8b", "http://localhost:11434/v1")
|
||||
|
||||
assert result == 131072
|
||||
|
||||
def test_prefers_explicit_num_ctx_from_modelfile(self):
|
||||
"""If the Modelfile sets num_ctx explicitly, that should take priority."""
|
||||
show_data = {
|
||||
"model_info": {"llama.context_length": 131072},
|
||||
"parameters": "num_ctx 32768\ntemperature 0.7",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("custom-model", "http://localhost:11434")
|
||||
|
||||
assert result == 32768
|
||||
|
||||
def test_returns_none_for_non_ollama_server(self):
|
||||
"""Should return None if the server is not Ollama."""
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"):
|
||||
result = query_ollama_num_ctx("model", "http://localhost:1234")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_connection_error(self):
|
||||
"""Should return None if the server is unreachable."""
|
||||
with patch("agent.model_metadata.detect_local_server_type", side_effect=Exception("timeout")):
|
||||
result = query_ollama_num_ctx("model", "http://localhost:11434")
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_on_404(self):
|
||||
"""Should return None if the model is not found."""
|
||||
mock_ctx, _ = _mock_httpx_client({}, status_code=404)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("nonexistent", "http://localhost:11434")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_strips_provider_prefix(self):
|
||||
"""Should strip 'local:' prefix from model name before querying."""
|
||||
show_data = {
|
||||
"model_info": {"qwen2.context_length": 32768},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, mock_client = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("local:qwen2.5:7b", "http://localhost:11434/v1")
|
||||
|
||||
# Verify the post was called with stripped name (no "local:" prefix)
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["name"] == "qwen2.5:7b" or call_args[0][1] is not None
|
||||
assert result == 32768
|
||||
|
||||
def test_handles_qwen2_architecture_key(self):
|
||||
"""Different model architectures use different key prefixes in model_info."""
|
||||
show_data = {
|
||||
"model_info": {"qwen2.context_length": 65536},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("qwen2.5:32b", "http://localhost:11434")
|
||||
|
||||
assert result == 65536
|
||||
|
||||
def test_returns_none_when_model_info_empty(self):
|
||||
"""Should return None if model_info has no context_length key."""
|
||||
show_data = {
|
||||
"model_info": {"llama.embedding_length": 4096},
|
||||
"parameters": "",
|
||||
}
|
||||
mock_ctx, _ = _mock_httpx_client(show_data)
|
||||
|
||||
with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"):
|
||||
import httpx
|
||||
with patch.object(httpx, "Client", return_value=mock_ctx):
|
||||
result = query_ollama_num_ctx("model", "http://localhost:11434")
|
||||
|
||||
assert result is None
|
||||
Reference in New Issue
Block a user