When LM Studio has a model loaded with a custom context size (e.g., 122K), prefer that over the model's max_context_length (e.g., 1M). This makes the TUI status bar show the actual runtime context window.
"""Tests for _query_local_context_length and the local server fallback in
|
|
get_model_context_length.
|
|
|
|
All tests use synthetic inputs — no filesystem or live server required.
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import json
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
|
|
|
import pytest
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _query_local_context_length — unit tests with mocked httpx
|
|
# ---------------------------------------------------------------------------
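
# Probe order assumed by the tests below, inferred from the mocked responses
# (a sketch of the fallback chain, not the actual implementation):
#
#   1. Ollama native:     POST {base}/api/show       -> model_info["llama.context_length"],
#                         else "num_ctx N" in the "parameters" string
#   2. LM Studio native:  GET  {base}/api/v1/models  -> loaded_instances config,
#                         else max_context_length (fuzzy key match)
#   3. OpenAI-compatible: GET  {base}/v1/models/{id} -> max_model_len or context_length
#   4. OpenAI-compatible: GET  {base}/v1/models      -> same keys, matched by id in "data"
#
# A miss or a network error at every tier yields None.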


class TestQueryLocalContextLengthOllama:
    """_query_local_context_length with server_type == 'ollama'."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def test_ollama_model_info_context_length(self):
        """Reads context length from model_info dict in /api/show response."""
        from agent.model_metadata import _query_local_context_length

        show_resp = self._make_resp(200, {
            "model_info": {"llama.context_length": 131072}
        })
        models_resp = self._make_resp(404, {})

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = show_resp
        client_mock.get.return_value = models_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 131072

    def test_ollama_parameters_num_ctx(self):
        """Falls back to num_ctx in parameters string when model_info lacks context_length."""
        from agent.model_metadata import _query_local_context_length

        show_resp = self._make_resp(200, {
            "model_info": {},
            "parameters": "num_ctx 32768\ntemperature 0.7\n"
        })
        models_resp = self._make_resp(404, {})

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = show_resp
        client_mock.get.return_value = models_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("some-model", "http://localhost:11434/v1")

        assert result == 32768

    def test_ollama_show_404_falls_through(self):
        """When /api/show returns 404, falls through to /v1/models/{model}."""
        from agent.model_metadata import _query_local_context_length

        show_resp = self._make_resp(404, {})
        model_detail_resp = self._make_resp(200, {"max_model_len": 65536})

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = show_resp
        client_mock.get.return_value = model_detail_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="ollama"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("some-model", "http://localhost:11434/v1")

        assert result == 65536
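
# The "parameters" field probed above is Ollama's modelfile dump: one
# "<name> <value>" pair per line. A plausible extraction, assuming that
# format (a sketch, not the actual implementation):
#
#     for line in parameters.splitlines():
#         parts = line.split()
#         if len(parts) >= 2 and parts[0] == "num_ctx":
#             return int(parts[1])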


class TestQueryLocalContextLengthVllm:
    """_query_local_context_length with vLLM-style /v1/models/{model} response."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def test_vllm_max_model_len(self):
        """Reads max_model_len from /v1/models/{model} response."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(200, {"id": "omnicoder-9b", "max_model_len": 100000})

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.return_value = detail_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:8000/v1")

        assert result == 100000

    def test_vllm_context_length_key(self):
        """Reads context_length from /v1/models/{model} response."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(200, {"id": "some-model", "context_length": 32768})

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.return_value = detail_resp

        with patch("agent.model_metadata.detect_local_server_type", return_value="vllm"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("some-model", "http://localhost:8000/v1")

        assert result == 32768
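
# Key preference assumed for the OpenAI-compatible detail endpoint, based on
# the two tests above (a sketch): vLLM reports "max_model_len", other servers
# may report "context_length", so something like
#
#     body.get("max_model_len") or body.get("context_length")
#
# covers both.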


class TestQueryLocalContextLengthModelsList:
    """_query_local_context_length: falls back to /v1/models list."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def test_models_list_max_model_len(self):
        """Finds context length for model in /v1/models list."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(404, {})
        list_resp = self._make_resp(200, {
            "data": [
                {"id": "other-model", "max_model_len": 4096},
                {"id": "omnicoder-9b", "max_model_len": 131072},
            ]
        })

        call_count = [0]

        def side_effect(url, **kwargs):
            call_count[0] += 1
            if call_count[0] == 1:
                return detail_resp  # /v1/models/omnicoder-9b
            return list_resp  # /v1/models

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.side_effect = side_effect

        with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")

        assert result == 131072

    def test_models_list_model_not_found_returns_none(self):
        """Returns None when model is not in the /v1/models list."""
        from agent.model_metadata import _query_local_context_length

        detail_resp = self._make_resp(404, {})
        list_resp = self._make_resp(200, {
            "data": [{"id": "other-model", "max_model_len": 4096}]
        })

        call_count = [0]

        def side_effect(url, **kwargs):
            call_count[0] += 1
            if call_count[0] == 1:
                return detail_resp
            return list_resp

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = self._make_resp(404, {})
        client_mock.get.side_effect = side_effect

        with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:1234")

        assert result is None


class TestQueryLocalContextLengthLmStudio:
    """_query_local_context_length with LM Studio native /api/v1/models response."""

    def _make_resp(self, status_code, body):
        resp = MagicMock()
        resp.status_code = status_code
        resp.json.return_value = body
        return resp

    def _make_client(self, native_resp, detail_resp, list_resp):
        """Build a mock httpx.Client with sequenced GET responses."""
        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.return_value = self._make_resp(404, {})

        responses = [native_resp, detail_resp, list_resp]
        call_idx = [0]

        def get_side_effect(url, **kwargs):
            idx = call_idx[0]
            call_idx[0] += 1
            if idx < len(responses):
                return responses[idx]
            return self._make_resp(404, {})

        client_mock.get.side_effect = get_side_effect
        return client_mock

    def test_lmstudio_exact_key_match(self):
        """Reads max_context_length when the key matches exactly."""
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {"key": "nvidia/nvidia-nemotron-super-49b-v1",
                 "id": "nvidia/nvidia-nemotron-super-49b-v1",
                 "max_context_length": 131072},
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia/nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 131072

    def test_lmstudio_slug_only_matches_key_with_publisher_prefix(self):
        """Fuzzy match: a bare model slug matches a key that includes a publisher prefix.

        When the user configures the model as "local:nvidia-nemotron-super-49b-v1"
        (slug only, no publisher) but LM Studio's native API stores it as
        "nvidia/nvidia-nemotron-super-49b-v1", the lookup must still succeed.
        """
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {"key": "nvidia/nvidia-nemotron-super-49b-v1",
                 "id": "nvidia/nvidia-nemotron-super-49b-v1",
                 "max_context_length": 131072},
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            # The model passed in is just the slug after stripping the "local:" prefix.
            result = _query_local_context_length(
                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 131072

    def test_lmstudio_v1_models_list_slug_fuzzy_match(self):
        """Fuzzy match also works for the /v1/models list when the exact match fails.

        LM Studio's OpenAI-compatible /v1/models returns ids like
        "nvidia/nvidia-nemotron-super-49b-v1", which must match the bare slug.
        """
        from agent.model_metadata import _query_local_context_length

        # native /api/v1/models: no match
        native_resp = self._make_resp(404, {})
        # /v1/models/{model}: no match
        detail_resp = self._make_resp(404, {})
        # /v1/models list: model found with publisher prefix, includes context_length
        list_resp = self._make_resp(200, {
            "data": [
                {"id": "nvidia/nvidia-nemotron-super-49b-v1", "context_length": 131072},
            ]
        })
        client_mock = self._make_client(native_resp, detail_resp, list_resp)

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 131072

    def test_lmstudio_loaded_instances_context_length(self):
        """Reads active context_length from loaded_instances when max_context_length is absent."""
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {
                    "key": "nvidia/nvidia-nemotron-super-49b-v1",
                    "id": "nvidia/nvidia-nemotron-super-49b-v1",
                    "loaded_instances": [
                        {"config": {"context_length": 65536}},
                    ],
                },
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia-nemotron-super-49b-v1", "http://192.168.1.22:1234/v1"
            )

        assert result == 65536

    def test_lmstudio_loaded_instance_beats_max_context_length(self):
        """The loaded_instances context_length takes priority over max_context_length.

        LM Studio may report max_context_length=1_048_576 (the model's theoretical
        maximum) while the actually loaded context is 122_651 (the runtime setting).
        The loaded value is the real constraint and must be preferred.
        """
        from agent.model_metadata import _query_local_context_length

        native_resp = self._make_resp(200, {
            "models": [
                {
                    "key": "nvidia/nvidia-nemotron-3-nano-4b",
                    "id": "nvidia/nvidia-nemotron-3-nano-4b",
                    "max_context_length": 1_048_576,
                    "loaded_instances": [
                        {"config": {"context_length": 122_651}},
                    ],
                },
            ]
        })
        client_mock = self._make_client(
            native_resp,
            self._make_resp(404, {}),
            self._make_resp(404, {}),
        )

        with patch("agent.model_metadata.detect_local_server_type", return_value="lm-studio"), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length(
                "nvidia-nemotron-3-nano-4b", "http://192.168.1.22:1234/v1"
            )

        assert result == 122_651, (
            f"Expected loaded-instance context (122651) but got {result}. "
            "max_context_length (1048576) must not win over loaded_instances."
        )
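
# The LM Studio resolution rule the class above encodes (a sketch inferred
# from the tests, not the real code): prefer the runtime context of a loaded
# instance over the model's theoretical maximum, and match keys fuzzily so a
# bare slug hits a "publisher/slug" entry.
#
#     def _lmstudio_context(entry):
#         for inst in entry.get("loaded_instances") or []:
#             ctx = (inst.get("config") or {}).get("context_length")
#             if ctx:
#                 return ctx
#         return entry.get("max_context_length")
#
#     def _keys_match(requested, key):
#         return key == requested or key.split("/")[-1] == requested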


class TestQueryLocalContextLengthNetworkError:
    """_query_local_context_length handles network failures gracefully."""

    def test_connection_error_returns_none(self):
        """Returns None when the server is unreachable."""
        from agent.model_metadata import _query_local_context_length

        client_mock = MagicMock()
        client_mock.__enter__.return_value = client_mock
        client_mock.__exit__.return_value = False
        client_mock.post.side_effect = Exception("Connection refused")
        client_mock.get.side_effect = Exception("Connection refused")

        with patch("agent.model_metadata.detect_local_server_type", return_value=None), \
             patch("httpx.Client", return_value=client_mock):
            result = _query_local_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result is None


# ---------------------------------------------------------------------------
# get_model_context_length - integration-style tests with mocked helpers
# ---------------------------------------------------------------------------
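
# Resolution order these tests assume inside get_model_context_length (a
# sketch; the real chain lives in agent.model_metadata): cached value ->
# endpoint metadata -> model metadata -> local server query (only when
# is_local_endpoint() is true and a base_url is set, with the result saved
# via save_context_length) -> CONTEXT_PROBE_TIERS[0] as the final default.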


class TestGetModelContextLengthLocalFallback:
    """get_model_context_length uses the local server query before falling back to 2M."""

    def test_local_endpoint_unknown_model_queries_server(self):
        """Unknown model on a local endpoint gets its context from the server, not the 2M default."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=131072), \
             patch("agent.model_metadata.save_context_length"):
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 131072

    def test_local_endpoint_unknown_model_result_is_cached(self):
        """Context length returned from the local server is persisted to the cache."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=131072), \
             patch("agent.model_metadata.save_context_length") as mock_save:
            get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        mock_save.assert_called_once_with("omnicoder-9b", "http://localhost:11434/v1", 131072)

    def test_local_endpoint_server_returns_none_falls_back_to_2m(self):
        """When the local server returns None, still falls back to the 2M probe tier."""
        from agent.model_metadata import get_model_context_length, CONTEXT_PROBE_TIERS

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=True), \
             patch("agent.model_metadata._query_local_context_length", return_value=None):
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == CONTEXT_PROBE_TIERS[0]

    def test_non_local_endpoint_does_not_query_local_server(self):
        """For non-local endpoints, _query_local_context_length is not called."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata.is_local_endpoint", return_value=False), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            get_model_context_length(
                "unknown-model", "https://some-cloud-api.example.com/v1"
            )

        mock_query.assert_not_called()

    def test_cached_result_skips_local_query(self):
        """A cached context length is returned without querying the local server."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=65536), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            result = get_model_context_length("omnicoder-9b", "http://localhost:11434/v1")

        assert result == 65536
        mock_query.assert_not_called()

    def test_no_base_url_does_not_query_local_server(self):
        """When base_url is empty, the local server is not queried."""
        from agent.model_metadata import get_model_context_length

        with patch("agent.model_metadata.get_cached_context_length", return_value=None), \
             patch("agent.model_metadata.fetch_endpoint_model_metadata", return_value={}), \
             patch("agent.model_metadata.fetch_model_metadata", return_value={}), \
             patch("agent.model_metadata._query_local_context_length") as mock_query:
            get_model_context_length("unknown-xyz-model", "")

        mock_query.assert_not_called()