Note: `_show_status()` now references `self.provider` and `self._provider_source`. Both attributes were added after the original PR was submitted, so the `cli_obj` fixture below sets them explicitly. (File: 162 lines · 6.6 KiB · Python)
"""Tests for the low context length warning in the CLI banner."""
|
|
|
|
import os
|
|
from types import SimpleNamespace
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
|
|
@pytest.fixture
def _isolate(tmp_path, monkeypatch):
    """Point HERMES_HOME at a fresh per-test directory.

    A ``.hermes`` folder is created inside pytest's ``tmp_path`` and exported
    through ``monkeypatch`` (restored automatically after the test), so the
    tests never read or write the developer's real Hermes config.
    """
    hermes_home = tmp_path.joinpath(".hermes")
    hermes_home.mkdir()
    monkeypatch.setenv("HERMES_HOME", f"{hermes_home}")
|
|
|
|
|
|
@pytest.fixture
def cli_obj(_isolate):
    """Create a minimal HermesCLI instance for banner testing.

    The instance is built with ``HermesCLI.__new__`` so ``__init__`` never
    runs. Only the attributes the banner code is expected to read are
    assigned by hand. Each test stubs ``cli.get_tool_definitions`` /
    ``cli.build_welcome_banner`` itself around its ``show_banner()`` call.
    Patching them here would be a no-op: the patches would be undone at
    ``return``, before any test touches the object.

    Depends on ``_isolate`` so HERMES_HOME already points at a temporary
    directory by the time ``cli`` is imported.
    """
    # Lazy import: ``cli`` is imported only after ``_isolate`` has redirected
    # HERMES_HOME for this test.
    from cli import HermesCLI

    obj = HermesCLI.__new__(HermesCLI)
    obj.model = "test-model"
    obj.enabled_toolsets = ["hermes-core"]
    obj.compact = False
    # Mock console: tests inspect console.print.call_args_list for warnings.
    obj.console = MagicMock()
    obj.session_id = None
    obj.api_key = "test"
    # Empty base_url; individual tests set a server URL to select a hint.
    obj.base_url = ""
    # Both are referenced by _show_status().
    obj.provider = "test"
    obj._provider_source = None
    # Stand-in agent exposing only a context compressor. Tests set
    # context_length per case; None models "not yet known".
    obj.agent = SimpleNamespace(
        context_compressor=SimpleNamespace(context_length=None),
    )
    return obj
|
|
|
|
|
|
class TestLowContextWarning:
    """Tests that the CLI warns about low context lengths.

    Each test sets ``context_length`` on the fixture's stand-in context
    compressor, runs ``show_banner()``, and inspects the stringified
    ``console.print`` calls for the warning line or a server-specific hint.
    """

    # Substring identifying the low-context warning in console output.
    WARNING_TEXT = "too low"

    @staticmethod
    def _printed_after_banner(cli_obj):
        """Run ``show_banner()`` with the tool/banner helpers stubbed out.

        Returns every ``console.print`` call recorded on the mock console,
        converted to ``str`` so callers can do substring checks.
        """
        with patch("cli.get_tool_definitions", return_value=[]), \
                patch("cli.build_welcome_banner"):
            cli_obj.show_banner()
        return [str(c) for c in cli_obj.console.print.call_args_list]

    @staticmethod
    def _containing(printed, needle):
        """Return the printed calls whose text contains *needle*."""
        return [line for line in printed if needle in line]

    def _warnings(self, cli_obj):
        """Render the banner and return only the low-context warning calls."""
        printed = self._printed_after_banner(cli_obj)
        return self._containing(printed, self.WARNING_TEXT)

    def test_no_warning_for_normal_context(self, cli_obj):
        """No warning when context is 32k+."""
        cli_obj.agent.context_compressor.context_length = 32768
        assert self._warnings(cli_obj) == []

    def test_warning_for_low_context(self, cli_obj):
        """Warning shown when context is 4096 (Ollama default)."""
        cli_obj.agent.context_compressor.context_length = 4096
        warnings = self._warnings(cli_obj)
        assert len(warnings) == 1
        # The length is expected to be rendered with a thousands separator.
        assert "4,096" in warnings[0]

    def test_warning_for_2048_context(self, cli_obj):
        """Warning shown for 2048 tokens (common LM Studio default)."""
        cli_obj.agent.context_compressor.context_length = 2048
        assert len(self._warnings(cli_obj)) == 1

    def test_warning_at_boundary(self, cli_obj):
        """8192 tokens is borderline but is still warned about."""
        cli_obj.agent.context_compressor.context_length = 8192
        assert len(self._warnings(cli_obj)) == 1

    def test_no_warning_above_boundary(self, cli_obj):
        """No warning at 16384."""
        cli_obj.agent.context_compressor.context_length = 16384
        assert self._warnings(cli_obj) == []

    def test_ollama_specific_hint(self, cli_obj):
        """Ollama-specific fix shown when port 11434 detected."""
        cli_obj.agent.context_compressor.context_length = 4096
        cli_obj.base_url = "http://localhost:11434/v1"
        printed = self._printed_after_banner(cli_obj)
        assert len(self._containing(printed, "OLLAMA_CONTEXT_LENGTH")) == 1

    def test_lm_studio_specific_hint(self, cli_obj):
        """LM Studio-specific fix shown when port 1234 detected."""
        cli_obj.agent.context_compressor.context_length = 2048
        cli_obj.base_url = "http://localhost:1234/v1"
        printed = self._printed_after_banner(cli_obj)
        assert len(self._containing(printed, "LM Studio")) == 1

    def test_generic_hint_for_other_servers(self, cli_obj):
        """Generic fix shown for unknown servers."""
        cli_obj.agent.context_compressor.context_length = 4096
        cli_obj.base_url = "http://localhost:8080/v1"
        printed = self._printed_after_banner(cli_obj)
        assert len(self._containing(printed, "config.yaml")) == 1

    def test_no_warning_when_no_context_length(self, cli_obj):
        """No warning when context length is not yet known."""
        cli_obj.agent.context_compressor.context_length = None
        assert self._warnings(cli_obj) == []

    def test_compact_banner_does_not_crash_on_narrow_terminal(self, cli_obj):
        """Compact mode must still emit the low-context warning.

        A 70-column terminal forces the compact banner. This guards against
        ``ctx_len`` only being defined on the full-banner path. The same
        tool/banner stubs as the other tests are applied, so the compact path
        cannot reach real tool loading.
        """
        cli_obj.agent.context_compressor.context_length = 4096
        narrow = os.terminal_size((70, 40))
        with patch("shutil.get_terminal_size", return_value=narrow), \
                patch("cli._build_compact_banner", return_value="compact banner"):
            warnings = self._warnings(cli_obj)
        assert len(warnings) == 1
|