Replaces the unsafe 128K fallback for unknown models with a descending probe strategy (2M → 1M → 512K → 200K → 128K → 64K → 32K). When a context-length error occurs, the agent steps down to the next tier and retries. The discovered limit is cached per model + provider (base URL) combination in ~/.hermes/context_length_cache.yaml, so subsequent sessions skip probing. API error messages are also parsed to extract the actual context limit (e.g. 'maximum context length is 32768 tokens'), which resolves the limit immediately without further probing. The CLI banner now displays the context window size next to the model name (e.g. 'claude-opus-4 · 200K context · Nous Research').

Changes:
- agent/model_metadata.py: CONTEXT_PROBE_TIERS, persistent cache (save/load/get), parse_context_limit_from_error(), get_next_probe_tier()
- agent/context_compressor.py: accepts base_url and passes it through to the metadata lookup
- run_agent.py: step-down logic in the context-error handler; caches the discovered limit on success
- cli.py + hermes_cli/banner.py: context length shown in the welcome banner
- tests: 22 new tests for probing, parsing, and caching

Addresses #132. The approach in PR #319 (an 8K default) was rejected as too conservative.
286 lines
11 KiB
Python
286 lines
11 KiB
Python
"""Tests for agent/model_metadata.py — token estimation and context lengths."""
|
|
|
|
import os
|
|
import tempfile
|
|
|
|
import pytest
|
|
import yaml
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from agent.model_metadata import (
|
|
CONTEXT_PROBE_TIERS,
|
|
DEFAULT_CONTEXT_LENGTHS,
|
|
estimate_tokens_rough,
|
|
estimate_messages_tokens_rough,
|
|
get_model_context_length,
|
|
get_next_probe_tier,
|
|
get_cached_context_length,
|
|
parse_context_limit_from_error,
|
|
save_context_length,
|
|
fetch_model_metadata,
|
|
_MODEL_CACHE_TTL,
|
|
)
|
|
|
|
|
|
# =========================================================================
|
|
# Token estimation
|
|
# =========================================================================
|
|
|
|
class TestEstimateTokensRough:
    """estimate_tokens_rough: character-count heuristic (about 4 chars per token)."""

    def test_empty_string(self):
        result = estimate_tokens_rough("")
        assert result == 0

    def test_none_returns_zero(self):
        result = estimate_tokens_rough(None)
        assert result == 0

    def test_known_length(self):
        # 400 characters at 4 chars/token -> exactly 100 tokens.
        payload = "a" * 400
        assert estimate_tokens_rough(payload) == 100

    def test_short_text(self):
        # 5 characters floor-divided by 4 -> 1 token.
        assert estimate_tokens_rough("hello") == 1

    def test_proportional(self):
        # A much longer input must yield a strictly larger estimate.
        few = estimate_tokens_rough("hello world")
        many = estimate_tokens_rough("hello world " * 100)
        assert many > few
|
|
|
|
|
|
class TestEstimateMessagesTokensRough:
    """estimate_messages_tokens_rough over chat-style message lists."""

    def test_empty_list(self):
        assert estimate_messages_tokens_rough([]) == 0

    def test_single_message(self):
        conversation = [{"role": "user", "content": "a" * 400}]
        assert estimate_messages_tokens_rough(conversation) > 0

    def test_multiple_messages(self):
        user_turn = {"role": "user", "content": "Hello"}
        assistant_turn = {"role": "assistant", "content": "Hi there, how can I help?"}
        assert estimate_messages_tokens_rough([user_turn, assistant_turn]) > 0
|
|
|
|
|
|
# =========================================================================
|
|
# Default context lengths
|
|
# =========================================================================
|
|
|
|
class TestDefaultContextLengths:
    """Sanity checks on the static DEFAULT_CONTEXT_LENGTHS table.

    Each model-family check first collects the keys containing the family
    fragment and asserts that at least one exists. Without that guard, the
    check would pass vacuously if the table were emptied or its keys renamed.
    """

    @staticmethod
    def _entries_containing(fragment):
        """Return the (key, value) pairs whose key contains *fragment*."""
        return [
            (key, value)
            for key, value in DEFAULT_CONTEXT_LENGTHS.items()
            if fragment in key
        ]

    def _assert_family(self, fragment, expected):
        """Assert every key containing *fragment* maps to *expected* (and that one exists)."""
        entries = self._entries_containing(fragment)
        # Guard against a vacuous pass when no key belongs to this family.
        assert entries, f"no DEFAULT_CONTEXT_LENGTHS key contains {fragment!r}"
        for key, value in entries:
            assert value == expected, f"{key} should be {expected}"

    def test_claude_models_200k(self):
        self._assert_family("claude", 200000)

    def test_gpt4_models_128k(self):
        self._assert_family("gpt-4", 128000)

    def test_gemini_models_1m(self):
        self._assert_family("gemini", 1048576)

    def test_all_values_positive(self):
        # An empty table would otherwise make the loop below a no-op.
        assert DEFAULT_CONTEXT_LENGTHS, "DEFAULT_CONTEXT_LENGTHS is empty"
        for key, value in DEFAULT_CONTEXT_LENGTHS.items():
            assert value > 0, f"{key} has non-positive context length"
|
|
|
|
|
|
# =========================================================================
|
|
# get_model_context_length (with mocked API)
|
|
# =========================================================================
|
|
|
|
class TestGetModelContextLength:
    """get_model_context_length with fetch_model_metadata patched out."""

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_known_model_from_api(self, mock_fetch):
        # A context_length present in the API metadata is returned as-is.
        api_payload = {"test/model": {"context_length": 32000}}
        mock_fetch.return_value = api_payload
        assert get_model_context_length("test/model") == 32000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_fallback_to_defaults(self, mock_fetch):
        # An empty API response leaves only the built-in defaults to consult.
        mock_fetch.return_value = {}
        assert get_model_context_length("anthropic/claude-sonnet-4") == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_unknown_model_returns_first_probe_tier(self, mock_fetch):
        # A model found nowhere starts at the largest probe tier (2M); it is
        # narrowed later when a context-length error comes back.
        mock_fetch.return_value = {}
        largest_tier = CONTEXT_PROBE_TIERS[0]
        assert get_model_context_length("unknown/never-heard-of-this") == largest_tier

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_partial_match_in_defaults(self, mock_fetch):
        # With no API data, "openai/gpt-4o" must be resolved from the
        # DEFAULT_CONTEXT_LENGTHS table (the gpt-4o family is 128K).
        mock_fetch.return_value = {}
        assert get_model_context_length("openai/gpt-4o") == 128000
|
|
|
|
|
|
# =========================================================================
|
|
# fetch_model_metadata (cache behavior)
|
|
# =========================================================================
|
|
|
|
class TestFetchModelMetadata:
    """fetch_model_metadata: in-memory caching and network-failure handling.

    The module-level cache globals are reset through pytest's ``monkeypatch``,
    which restores their original values after each test. Assigning them
    directly would leave fake metadata (with a fresh timestamp) in the module
    cache for whichever test runs next.
    """

    @staticmethod
    def _reset_module_cache(monkeypatch):
        """Empty the in-memory metadata cache for the duration of one test."""
        import agent.model_metadata as mm

        monkeypatch.setattr(mm, "_model_metadata_cache", {})
        monkeypatch.setattr(mm, "_model_metadata_cache_time", 0)

    @patch("agent.model_metadata.requests.get")
    def test_caches_result(self, mock_get, monkeypatch):
        self._reset_module_cache(monkeypatch)

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [
                {"id": "test/model", "context_length": 99999, "name": "Test Model"}
            ]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        # First call fetches from the (mocked) HTTP endpoint.
        result1 = fetch_model_metadata(force_refresh=True)
        assert "test/model" in result1
        assert mock_get.call_count == 1

        # Second call is served from the in-memory cache.
        result2 = fetch_model_metadata()
        assert "test/model" in result2
        assert mock_get.call_count == 1  # Not called again

    @patch("agent.model_metadata.requests.get")
    def test_api_failure_returns_empty(self, mock_get, monkeypatch):
        self._reset_module_cache(monkeypatch)

        mock_get.side_effect = Exception("Network error")
        result = fetch_model_metadata(force_refresh=True)
        assert result == {}
|
|
|
|
|
|
# =========================================================================
|
|
# Context probe tiers
|
|
# =========================================================================
|
|
|
|
class TestContextProbeTiers:
    """Shape of the CONTEXT_PROBE_TIERS ladder (largest first, 2M down to 32K)."""

    def test_tiers_descending(self):
        # Pair each tier with its successor; every step must strictly shrink.
        for higher, lower in zip(CONTEXT_PROBE_TIERS, CONTEXT_PROBE_TIERS[1:]):
            assert higher > lower, f"tier {higher} is not larger than the next tier {lower}"

    def test_first_tier_is_2m(self):
        assert CONTEXT_PROBE_TIERS[0] == 2_000_000

    def test_last_tier_is_32k(self):
        assert CONTEXT_PROBE_TIERS[-1] == 32_000
|
|
|
|
|
|
class TestGetNextProbeTier:
    """get_next_probe_tier: the next probe tier strictly below the given value."""

    def test_from_2m(self):
        next_tier = get_next_probe_tier(2_000_000)
        assert next_tier == 1_000_000

    def test_from_1m(self):
        next_tier = get_next_probe_tier(1_000_000)
        assert next_tier == 512_000

    def test_from_128k(self):
        next_tier = get_next_probe_tier(128_000)
        assert next_tier == 64_000

    def test_from_32k_returns_none(self):
        # 32K is the smallest tier, so there is nothing left to step down to.
        assert get_next_probe_tier(32_000) is None

    def test_from_below_min_returns_none(self):
        # Values already under the smallest tier have no next tier either.
        assert get_next_probe_tier(16_000) is None

    def test_from_arbitrary_value(self):
        # A value that is not itself a tier (300K lies between 512K and 200K)
        # steps down to the nearest lower tier.
        assert get_next_probe_tier(300_000) == 200_000
|
|
|
|
|
|
# =========================================================================
|
|
# Error message parsing
|
|
# =========================================================================
|
|
|
|
class TestParseContextLimitFromError:
    """parse_context_limit_from_error across provider error-message formats."""

    def test_openai_format(self):
        error_text = (
            "This model's maximum context length is 32768 tokens. "
            "However, your messages resulted in 45000 tokens."
        )
        assert parse_context_limit_from_error(error_text) == 32768

    def test_context_length_exceeded(self):
        error_text = "context_length_exceeded: maximum context length is 131072"
        assert parse_context_limit_from_error(error_text) == 131072

    def test_context_size_exceeded(self):
        error_text = "Maximum context size 65536 exceeded"
        assert parse_context_limit_from_error(error_text) == 65536

    def test_no_limit_in_message(self):
        error_text = "Something went wrong with the API"
        assert parse_context_limit_from_error(error_text) is None

    def test_unreasonable_number_rejected(self):
        # 42 is far too small to be a real context window, so it is discarded.
        error_text = "context length is 42 tokens"
        assert parse_context_limit_from_error(error_text) is None

    def test_ollama_format(self):
        error_text = "Context size has been exceeded. Maximum context size is 32768"
        assert parse_context_limit_from_error(error_text) == 32768
|
|
|
|
|
|
# =========================================================================
|
|
# Persistent context length cache
|
|
# =========================================================================
|
|
|
|
class TestContextLengthCache:
    """Round-trips through the on-disk (model, base_url) context-length cache."""

    @staticmethod
    def _cache_at(path):
        """Return a context manager that points the cache file at *path*."""
        return patch("agent.model_metadata._get_context_cache_path", return_value=path)

    def test_save_and_load(self, tmp_path):
        cache_path = tmp_path / "context_length_cache.yaml"
        url = "http://localhost:8080/v1"
        with self._cache_at(cache_path):
            save_context_length("test/model", url, 32768)
            loaded = get_cached_context_length("test/model", url)
            assert loaded == 32768

    def test_missing_cache_returns_none(self, tmp_path):
        # The file is never created, so the lookup has nothing to read.
        cache_path = tmp_path / "nonexistent.yaml"
        with self._cache_at(cache_path):
            assert get_cached_context_length("test/model", "http://x") is None

    def test_multiple_models_cached(self, tmp_path):
        cache_path = tmp_path / "context_length_cache.yaml"
        with self._cache_at(cache_path):
            save_context_length("model-a", "http://a", 64000)
            save_context_length("model-b", "http://b", 128000)
            assert get_cached_context_length("model-a", "http://a") == 64000
            assert get_cached_context_length("model-b", "http://b") == 128000

    def test_same_model_different_providers(self, tmp_path):
        # The same model name served by two base URLs keeps separate entries.
        cache_path = tmp_path / "context_length_cache.yaml"
        local_url = "http://local:8080"
        remote_url = "https://openrouter.ai/api/v1"
        with self._cache_at(cache_path):
            save_context_length("llama-3", local_url, 32768)
            save_context_length("llama-3", remote_url, 131072)
            assert get_cached_context_length("llama-3", local_url) == 32768
            assert get_cached_context_length("llama-3", remote_url) == 131072

    def test_idempotent_save(self, tmp_path):
        cache_path = tmp_path / "context_length_cache.yaml"
        with self._cache_at(cache_path):
            # Saving the identical value twice must not produce a second entry.
            for _ in range(2):
                save_context_length("model", "http://x", 32768)
            data = yaml.safe_load(cache_path.read_text())
            assert len(data["context_lengths"]) == 1

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_cached_value_takes_priority(self, mock_fetch, tmp_path):
        """Cached context length should be used before API or defaults."""
        mock_fetch.return_value = {}
        cache_path = tmp_path / "context_length_cache.yaml"
        with self._cache_at(cache_path):
            save_context_length("unknown/model", "http://local", 65536)
            resolved = get_model_context_length("unknown/model", base_url="http://local")
            assert resolved == 65536
|