feat: smart context length probing with persistent caching + banner display
Replaces the unsafe 128K fallback for unknown models with a descending probe strategy (2M → 1M → 512K → 200K → 128K → 64K → 32K). When a context-length error occurs, the agent steps down tiers and retries. The discovered limit is cached per model+provider combo in ~/.hermes/context_length_cache.yaml so subsequent sessions skip probing. Also parses API error messages to extract the actual context limit (e.g. 'maximum context length is 32768 tokens') for instant resolution. The CLI banner now displays the context window size next to the model name (e.g. 'claude-opus-4 · 200K context · Nous Research'). Changes: - agent/model_metadata.py: CONTEXT_PROBE_TIERS, persistent cache (save/load/get), parse_context_limit_from_error(), get_next_probe_tier() - agent/context_compressor.py: accepts base_url, passes to metadata - run_agent.py: step-down logic in context error handler, caches on success - cli.py + hermes_cli/banner.py: context length in welcome banner - tests: 22 new tests for probing, parsing, and caching Addresses #132. PR #319's approach (8K default) rejected — too conservative.
This commit is contained in:
@@ -1,13 +1,22 @@
|
||||
"""Tests for agent/model_metadata.py — token estimation and context lengths."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.model_metadata import (
|
||||
CONTEXT_PROBE_TIERS,
|
||||
DEFAULT_CONTEXT_LENGTHS,
|
||||
estimate_tokens_rough,
|
||||
estimate_messages_tokens_rough,
|
||||
get_model_context_length,
|
||||
get_next_probe_tier,
|
||||
get_cached_context_length,
|
||||
parse_context_limit_from_error,
|
||||
save_context_length,
|
||||
fetch_model_metadata,
|
||||
_MODEL_CACHE_TTL,
|
||||
)
|
||||
@@ -101,10 +110,10 @@ class TestGetModelContextLength:
|
||||
assert result == 200000
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_unknown_model_returns_128k(self, mock_fetch):
|
||||
def test_unknown_model_returns_first_probe_tier(self, mock_fetch):
|
||||
mock_fetch.return_value = {}
|
||||
result = get_model_context_length("unknown/never-heard-of-this")
|
||||
assert result == 128000
|
||||
assert result == CONTEXT_PROBE_TIERS[0] # 2M — will be narrowed on context error
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_partial_match_in_defaults(self, mock_fetch):
|
||||
@@ -154,3 +163,123 @@ class TestFetchModelMetadata:
|
||||
mock_get.side_effect = Exception("Network error")
|
||||
result = fetch_model_metadata(force_refresh=True)
|
||||
assert result == {}
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Context probe tiers
|
||||
# =========================================================================
|
||||
|
||||
class TestContextProbeTiers:
|
||||
def test_tiers_descending(self):
|
||||
for i in range(len(CONTEXT_PROBE_TIERS) - 1):
|
||||
assert CONTEXT_PROBE_TIERS[i] > CONTEXT_PROBE_TIERS[i + 1]
|
||||
|
||||
def test_first_tier_is_2m(self):
|
||||
assert CONTEXT_PROBE_TIERS[0] == 2_000_000
|
||||
|
||||
def test_last_tier_is_32k(self):
|
||||
assert CONTEXT_PROBE_TIERS[-1] == 32_000
|
||||
|
||||
|
||||
class TestGetNextProbeTier:
|
||||
def test_from_2m(self):
|
||||
assert get_next_probe_tier(2_000_000) == 1_000_000
|
||||
|
||||
def test_from_1m(self):
|
||||
assert get_next_probe_tier(1_000_000) == 512_000
|
||||
|
||||
def test_from_128k(self):
|
||||
assert get_next_probe_tier(128_000) == 64_000
|
||||
|
||||
def test_from_32k_returns_none(self):
|
||||
assert get_next_probe_tier(32_000) is None
|
||||
|
||||
def test_from_below_min_returns_none(self):
|
||||
assert get_next_probe_tier(16_000) is None
|
||||
|
||||
def test_from_arbitrary_value(self):
|
||||
# 300K is between 512K and 200K, should return 200K
|
||||
assert get_next_probe_tier(300_000) == 200_000
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Error message parsing
|
||||
# =========================================================================
|
||||
|
||||
class TestParseContextLimitFromError:
|
||||
def test_openai_format(self):
|
||||
msg = "This model's maximum context length is 32768 tokens. However, your messages resulted in 45000 tokens."
|
||||
assert parse_context_limit_from_error(msg) == 32768
|
||||
|
||||
def test_context_length_exceeded(self):
|
||||
msg = "context_length_exceeded: maximum context length is 131072"
|
||||
assert parse_context_limit_from_error(msg) == 131072
|
||||
|
||||
def test_context_size_exceeded(self):
|
||||
msg = "Maximum context size 65536 exceeded"
|
||||
assert parse_context_limit_from_error(msg) == 65536
|
||||
|
||||
def test_no_limit_in_message(self):
|
||||
msg = "Something went wrong with the API"
|
||||
assert parse_context_limit_from_error(msg) is None
|
||||
|
||||
def test_unreasonable_number_rejected(self):
|
||||
msg = "context length is 42 tokens" # too small
|
||||
assert parse_context_limit_from_error(msg) is None
|
||||
|
||||
def test_ollama_format(self):
|
||||
msg = "Context size has been exceeded. Maximum context size is 32768"
|
||||
assert parse_context_limit_from_error(msg) == 32768
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Persistent context length cache
|
||||
# =========================================================================
|
||||
|
||||
class TestContextLengthCache:
|
||||
def test_save_and_load(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("test/model", "http://localhost:8080/v1", 32768)
|
||||
result = get_cached_context_length("test/model", "http://localhost:8080/v1")
|
||||
assert result == 32768
|
||||
|
||||
def test_missing_cache_returns_none(self, tmp_path):
|
||||
cache_file = tmp_path / "nonexistent.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
assert get_cached_context_length("test/model", "http://x") is None
|
||||
|
||||
def test_multiple_models_cached(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("model-a", "http://a", 64000)
|
||||
save_context_length("model-b", "http://b", 128000)
|
||||
assert get_cached_context_length("model-a", "http://a") == 64000
|
||||
assert get_cached_context_length("model-b", "http://b") == 128000
|
||||
|
||||
def test_same_model_different_providers(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("llama-3", "http://local:8080", 32768)
|
||||
save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072)
|
||||
assert get_cached_context_length("llama-3", "http://local:8080") == 32768
|
||||
assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072
|
||||
|
||||
def test_idempotent_save(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("model", "http://x", 32768)
|
||||
save_context_length("model", "http://x", 32768) # same value
|
||||
with open(cache_file) as f:
|
||||
data = yaml.safe_load(f)
|
||||
assert len(data["context_lengths"]) == 1
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_cached_value_takes_priority(self, mock_fetch, tmp_path):
|
||||
"""Cached context length should be used before API or defaults."""
|
||||
mock_fetch.return_value = {}
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("unknown/model", "http://local", 65536)
|
||||
result = get_model_context_length("unknown/model", base_url="http://local")
|
||||
assert result == 65536
|
||||
|
||||
Reference in New Issue
Block a user