diff --git a/agent/model_metadata.py b/agent/model_metadata.py index cf3799799..319b9bdcd 100644 --- a/agent/model_metadata.py +++ b/agent/model_metadata.py @@ -158,6 +158,8 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]: r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})', r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})', r'(\d{4,})\s*(?:token)?\s*(?:context|limit)', + r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum" + r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum" ] for pattern in patterns: match = re.search(pattern, error_lower) diff --git a/tests/agent/test_model_metadata.py b/tests/agent/test_model_metadata.py index ffc98cb26..b58e6a2e5 100644 --- a/tests/agent/test_model_metadata.py +++ b/tests/agent/test_model_metadata.py @@ -1,10 +1,22 @@ -"""Tests for agent/model_metadata.py — token estimation and context lengths.""" +"""Tests for agent/model_metadata.py — token estimation, context lengths, +probing, caching, and error parsing. + +Coverage levels: + Token estimation — concrete value assertions, edge cases + Context length lookup — resolution order, fuzzy match, cache priority + API metadata fetch — caching, TTL, canonical slugs, stale fallback + Probe tiers — descending, boundaries, extreme inputs + Error parsing — OpenAI, Ollama, Anthropic, edge cases + Persistent cache — save/load, corruption, update, provider isolation +""" import os +import time import tempfile import pytest import yaml +from pathlib import Path from unittest.mock import patch, MagicMock from agent.model_metadata import ( @@ -34,12 +46,9 @@ class TestEstimateTokensRough: assert estimate_tokens_rough(None) == 0 def test_known_length(self): - # 400 chars / 4 = 100 tokens - text = "a" * 400 - assert estimate_tokens_rough(text) == 100 + assert estimate_tokens_rough("a" * 400) == 100 def test_short_text(self): - # "hello" = 5 chars -> 5 // 4 = 1 assert estimate_tokens_rough("hello") == 1 def test_proportional(self): @@ -47,23 +56,48 @@ class TestEstimateTokensRough: long = estimate_tokens_rough("hello world " * 100) assert long > short + def test_unicode_multibyte(self): + """Unicode chars are still 1 Python char each — 4 chars/token holds.""" + text = "你好世界" # 4 CJK characters + assert estimate_tokens_rough(text) == 1 + class TestEstimateMessagesTokensRough: def test_empty_list(self): assert estimate_messages_tokens_rough([]) == 0 - def test_single_message(self): - msgs = [{"role": "user", "content": "a" * 400}] - result = estimate_messages_tokens_rough(msgs) - assert result > 0 + def test_single_message_concrete_value(self): + """Verify against known str(msg) length.""" + msg = {"role": "user", "content": "a" * 400} + result = estimate_messages_tokens_rough([msg]) + expected = len(str(msg)) // 4 + assert result == expected - def test_multiple_messages(self): + def test_multiple_messages_additive(self): msgs = [ {"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there, how can I help?"}, ] result = estimate_messages_tokens_rough(msgs) + expected = sum(len(str(m)) for m in msgs) // 4 + assert result == expected + + def test_tool_call_message(self): + """Tool call messages with no 'content' key still contribute tokens.""" + msg = {"role": "assistant", "content": None, + "tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]} + result = estimate_messages_tokens_rough([msg]) assert result > 0 + assert result == len(str(msg)) // 4 + + def test_message_with_list_content(self): + """Vision messages with multimodal content arrays.""" + msg = {"role": "user", "content": [ + {"type": "text", "text": "describe"}, + {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}} + ]} + result = estimate_messages_tokens_rough([msg]) + assert result == len(str(msg)) // 4 # ========================================================================= @@ -90,9 +124,12 @@ class TestDefaultContextLengths: for key, value in DEFAULT_CONTEXT_LENGTHS.items(): assert value > 0, f"{key} has non-positive context length" + def test_dict_is_not_empty(self): + assert len(DEFAULT_CONTEXT_LENGTHS) >= 10 + # ========================================================================= -# get_model_context_length (with mocked API) +# get_model_context_length — resolution order # ========================================================================= class TestGetModelContextLength: @@ -105,62 +142,146 @@ class TestGetModelContextLength: @patch("agent.model_metadata.fetch_model_metadata") def test_fallback_to_defaults(self, mock_fetch): - mock_fetch.return_value = {} # API returns nothing - result = get_model_context_length("anthropic/claude-sonnet-4") - assert result == 200000 + mock_fetch.return_value = {} + assert get_model_context_length("anthropic/claude-sonnet-4") == 200000 @patch("agent.model_metadata.fetch_model_metadata") def test_unknown_model_returns_first_probe_tier(self, mock_fetch): mock_fetch.return_value = {} - result = get_model_context_length("unknown/never-heard-of-this") - assert result == CONTEXT_PROBE_TIERS[0] # 2M — will be narrowed on context error + assert get_model_context_length("unknown/never-heard-of-this") == CONTEXT_PROBE_TIERS[0] @patch("agent.model_metadata.fetch_model_metadata") def test_partial_match_in_defaults(self, mock_fetch): mock_fetch.return_value = {} - # "gpt-4o" is a substring match for "openai/gpt-4o" - result = get_model_context_length("openai/gpt-4o") - assert result == 128000 + assert get_model_context_length("openai/gpt-4o") == 128000 + + @patch("agent.model_metadata.fetch_model_metadata") + def test_api_missing_context_length_key(self, mock_fetch): + """Model in API but without context_length → defaults to 128000.""" + mock_fetch.return_value = {"test/model": {"name": "Test"}} + assert get_model_context_length("test/model") == 128000 + + @patch("agent.model_metadata.fetch_model_metadata") + def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path): + """Persistent cache should be checked BEFORE API metadata.""" + mock_fetch.return_value = {"my/model": {"context_length": 999999}} + cache_file = tmp_path / "cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("my/model", "http://local", 32768) + result = get_model_context_length("my/model", base_url="http://local") + assert result == 32768 # cache wins over API's 999999 + + @patch("agent.model_metadata.fetch_model_metadata") + def test_no_base_url_skips_cache(self, mock_fetch, tmp_path): + """Without base_url, cache lookup is skipped.""" + mock_fetch.return_value = {} + cache_file = tmp_path / "cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("custom/model", "http://local", 32768) + # No base_url → cache skipped → falls to probe tier + result = get_model_context_length("custom/model") + assert result == CONTEXT_PROBE_TIERS[0] # ========================================================================= -# fetch_model_metadata (cache behavior) +# fetch_model_metadata — caching, TTL, slugs, failures # ========================================================================= class TestFetchModelMetadata: - @patch("agent.model_metadata.requests.get") - def test_caches_result(self, mock_get): + def _reset_cache(self): import agent.model_metadata as mm - # Reset cache mm._model_metadata_cache = {} mm._model_metadata_cache_time = 0 + @patch("agent.model_metadata.requests.get") + def test_caches_result(self, mock_get): + self._reset_cache() mock_response = MagicMock() mock_response.json.return_value = { - "data": [ - {"id": "test/model", "context_length": 99999, "name": "Test Model"} - ] + "data": [{"id": "test/model", "context_length": 99999, "name": "Test"}] } mock_response.raise_for_status = MagicMock() mock_get.return_value = mock_response - # First call fetches result1 = fetch_model_metadata(force_refresh=True) assert "test/model" in result1 assert mock_get.call_count == 1 - # Second call uses cache result2 = fetch_model_metadata() assert "test/model" in result2 - assert mock_get.call_count == 1 # Not called again + assert mock_get.call_count == 1 # cached @patch("agent.model_metadata.requests.get") - def test_api_failure_returns_empty(self, mock_get): + def test_api_failure_returns_empty_on_cold_cache(self, mock_get): + self._reset_cache() + mock_get.side_effect = Exception("Network error") + result = fetch_model_metadata(force_refresh=True) + assert result == {} + + @patch("agent.model_metadata.requests.get") + def test_api_failure_returns_stale_cache(self, mock_get): + """On API failure with existing cache, stale data is returned.""" import agent.model_metadata as mm - mm._model_metadata_cache = {} - mm._model_metadata_cache_time = 0 + mm._model_metadata_cache = {"old/model": {"context_length": 50000}} + mm._model_metadata_cache_time = 0 # expired mock_get.side_effect = Exception("Network error") + result = fetch_model_metadata(force_refresh=True) + assert "old/model" in result + assert result["old/model"]["context_length"] == 50000 + + @patch("agent.model_metadata.requests.get") + def test_canonical_slug_aliasing(self, mock_get): + """Models with canonical_slug get indexed under both IDs.""" + self._reset_cache() + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [{ + "id": "anthropic/claude-3.5-sonnet:beta", + "canonical_slug": "anthropic/claude-3.5-sonnet", + "context_length": 200000, + "name": "Claude 3.5 Sonnet" + }] + } + mock_response.raise_for_status = MagicMock() + mock_get.return_value = mock_response + + result = fetch_model_metadata(force_refresh=True) + # Both the original ID and canonical slug should work + assert "anthropic/claude-3.5-sonnet:beta" in result + assert "anthropic/claude-3.5-sonnet" in result + assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000 + + @patch("agent.model_metadata.requests.get") + def test_ttl_expiry_triggers_refetch(self, mock_get): + """Cache expires after _MODEL_CACHE_TTL seconds.""" + import agent.model_metadata as mm + self._reset_cache() + + mock_response = MagicMock() + mock_response.json.return_value = { + "data": [{"id": "m1", "context_length": 1000, "name": "M1"}] + } + mock_response.raise_for_status = MagicMock() + mock_get.return_value = mock_response + + fetch_model_metadata(force_refresh=True) + assert mock_get.call_count == 1 + + # Simulate TTL expiry + mm._model_metadata_cache_time = time.time() - _MODEL_CACHE_TTL - 1 + fetch_model_metadata() + assert mock_get.call_count == 2 # refetched + + @patch("agent.model_metadata.requests.get") + def test_malformed_json_no_data_key(self, mock_get): + """API returns JSON without 'data' key — empty cache, no crash.""" + self._reset_cache() + mock_response = MagicMock() + mock_response.json.return_value = {"error": "something"} + mock_response.raise_for_status = MagicMock() + mock_get.return_value = mock_response + result = fetch_model_metadata(force_refresh=True) assert result == {} @@ -198,9 +319,15 @@ class TestGetNextProbeTier: assert get_next_probe_tier(16_000) is None def test_from_arbitrary_value(self): - # 300K is between 512K and 200K, should return 200K assert get_next_probe_tier(300_000) == 200_000 + def test_above_max_tier(self): + """Value above 2M should return 2M.""" + assert get_next_probe_tier(5_000_000) == 2_000_000 + + def test_zero_returns_none(self): + assert get_next_probe_tier(0) is None + # ========================================================================= # Error message parsing @@ -220,17 +347,35 @@ class TestParseContextLimitFromError: assert parse_context_limit_from_error(msg) == 65536 def test_no_limit_in_message(self): - msg = "Something went wrong with the API" - assert parse_context_limit_from_error(msg) is None + assert parse_context_limit_from_error("Something went wrong with the API") is None - def test_unreasonable_number_rejected(self): - msg = "context length is 42 tokens" # too small - assert parse_context_limit_from_error(msg) is None + def test_unreasonable_small_number_rejected(self): + assert parse_context_limit_from_error("context length is 42 tokens") is None def test_ollama_format(self): msg = "Context size has been exceeded. Maximum context size is 32768" assert parse_context_limit_from_error(msg) == 32768 + def test_anthropic_format(self): + msg = "prompt is too long: 250000 tokens > 200000 maximum" + # Should extract 200000 (the limit), not 250000 (the input size) + assert parse_context_limit_from_error(msg) == 200000 + + def test_lmstudio_format(self): + msg = "Error: context window of 4096 tokens exceeded" + assert parse_context_limit_from_error(msg) == 4096 + + def test_completely_unrelated_error(self): + assert parse_context_limit_from_error("Invalid API key") is None + + def test_empty_string(self): + assert parse_context_limit_from_error("") is None + + def test_number_outside_reasonable_range(self): + """Very large number (>10M) should be rejected.""" + msg = "maximum context length is 99999999999" + assert parse_context_limit_from_error(msg) is None + # ========================================================================= # Persistent context length cache @@ -238,11 +383,10 @@ class TestParseContextLimitFromError: class TestContextLengthCache: def test_save_and_load(self, tmp_path): - cache_file = tmp_path / "context_length_cache.yaml" + cache_file = tmp_path / "cache.yaml" with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): save_context_length("test/model", "http://localhost:8080/v1", 32768) - result = get_cached_context_length("test/model", "http://localhost:8080/v1") - assert result == 32768 + assert get_cached_context_length("test/model", "http://localhost:8080/v1") == 32768 def test_missing_cache_returns_none(self, tmp_path): cache_file = tmp_path / "nonexistent.yaml" @@ -250,7 +394,7 @@ class TestContextLengthCache: assert get_cached_context_length("test/model", "http://x") is None def test_multiple_models_cached(self, tmp_path): - cache_file = tmp_path / "context_length_cache.yaml" + cache_file = tmp_path / "cache.yaml" with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): save_context_length("model-a", "http://a", 64000) save_context_length("model-b", "http://b", 128000) @@ -258,7 +402,7 @@ class TestContextLengthCache: assert get_cached_context_length("model-b", "http://b") == 128000 def test_same_model_different_providers(self, tmp_path): - cache_file = tmp_path / "context_length_cache.yaml" + cache_file = tmp_path / "cache.yaml" with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): save_context_length("llama-3", "http://local:8080", 32768) save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072) @@ -266,20 +410,49 @@ class TestContextLengthCache: assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072 def test_idempotent_save(self, tmp_path): - cache_file = tmp_path / "context_length_cache.yaml" + cache_file = tmp_path / "cache.yaml" with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): save_context_length("model", "http://x", 32768) - save_context_length("model", "http://x", 32768) # same value + save_context_length("model", "http://x", 32768) with open(cache_file) as f: data = yaml.safe_load(f) assert len(data["context_lengths"]) == 1 + def test_update_existing_value(self, tmp_path): + """Saving a different value for the same key overwrites it.""" + cache_file = tmp_path / "cache.yaml" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length("model", "http://x", 128000) + save_context_length("model", "http://x", 64000) + assert get_cached_context_length("model", "http://x") == 64000 + + def test_corrupted_yaml_returns_empty(self, tmp_path): + """Corrupted cache file is handled gracefully.""" + cache_file = tmp_path / "cache.yaml" + cache_file.write_text("{{{{not valid yaml: [[[") + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + assert get_cached_context_length("model", "http://x") is None + + def test_wrong_structure_returns_none(self, tmp_path): + """YAML that loads but has wrong structure.""" + cache_file = tmp_path / "cache.yaml" + cache_file.write_text("just_a_string\n") + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + assert get_cached_context_length("model", "http://x") is None + @patch("agent.model_metadata.fetch_model_metadata") def test_cached_value_takes_priority(self, mock_fetch, tmp_path): - """Cached context length should be used before API or defaults.""" mock_fetch.return_value = {} - cache_file = tmp_path / "context_length_cache.yaml" + cache_file = tmp_path / "cache.yaml" with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): save_context_length("unknown/model", "http://local", 65536) - result = get_model_context_length("unknown/model", base_url="http://local") - assert result == 65536 + assert get_model_context_length("unknown/model", base_url="http://local") == 65536 + + def test_special_chars_in_model_name(self, tmp_path): + """Model names with colons, slashes, etc. don't break the cache.""" + cache_file = tmp_path / "cache.yaml" + model = "anthropic/claude-3.5-sonnet:beta" + url = "https://api.example.com/v1" + with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file): + save_context_length(model, url, 200000) + assert get_cached_context_length(model, url) == 200000 diff --git a/tests/tools/test_web_tools_config.py b/tests/tools/test_web_tools_config.py index f1da27823..4bc49166f 100644 --- a/tests/tools/test_web_tools_config.py +++ b/tests/tools/test_web_tools_config.py @@ -1,4 +1,9 @@ -"""Tests for Firecrawl client configuration.""" +"""Tests for Firecrawl client configuration and singleton behavior. + +Coverage: + _get_firecrawl_client() — configuration matrix, singleton caching, + constructor failure recovery, return value verification, edge cases. +""" import os import pytest @@ -6,69 +11,109 @@ from unittest.mock import patch, MagicMock class TestFirecrawlClientConfig: - """Test suite for Firecrawl client initialization with API URL support.""" + """Test suite for Firecrawl client initialization.""" - def teardown_method(self): - """Reset client between tests.""" + def setup_method(self): + """Reset client and env vars before each test.""" import tools.web_tools - tools.web_tools._firecrawl_client = None - - def _clear_firecrawl_env(self): - """Remove Firecrawl env vars so tests start clean.""" for key in ("FIRECRAWL_API_KEY", "FIRECRAWL_API_URL"): os.environ.pop(key, None) - def test_client_with_api_key_only(self): - """Test client initialization with only API key (cloud mode).""" - self._clear_firecrawl_env() - with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "test-key"}, clear=False): - with patch("tools.web_tools.Firecrawl") as mock_firecrawl: + def teardown_method(self): + """Reset client after each test.""" + import tools.web_tools + tools.web_tools._firecrawl_client = None + for key in ("FIRECRAWL_API_KEY", "FIRECRAWL_API_URL"): + os.environ.pop(key, None) + + # ── Configuration matrix ───────────────────────────────────────── + + def test_cloud_mode_key_only(self): + """API key without URL → cloud Firecrawl.""" + with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}): + with patch("tools.web_tools.Firecrawl") as mock_fc: from tools.web_tools import _get_firecrawl_client + result = _get_firecrawl_client() + mock_fc.assert_called_once_with(api_key="fc-test") + assert result is mock_fc.return_value - _get_firecrawl_client() - mock_firecrawl.assert_called_once_with(api_key="test-key") - - def test_client_with_api_key_and_url(self): - """Test client initialization with API key and custom URL.""" - self._clear_firecrawl_env() - with patch.dict( - os.environ, - { - "FIRECRAWL_API_KEY": "test-key", - "FIRECRAWL_API_URL": "http://localhost:3002", - }, - clear=False, - ): - with patch("tools.web_tools.Firecrawl") as mock_firecrawl: + def test_self_hosted_with_key(self): + """Both key + URL → self-hosted with auth.""" + with patch.dict(os.environ, { + "FIRECRAWL_API_KEY": "fc-test", + "FIRECRAWL_API_URL": "http://localhost:3002", + }): + with patch("tools.web_tools.Firecrawl") as mock_fc: from tools.web_tools import _get_firecrawl_client - - _get_firecrawl_client() - mock_firecrawl.assert_called_once_with( - api_key="test-key", api_url="http://localhost:3002" + result = _get_firecrawl_client() + mock_fc.assert_called_once_with( + api_key="fc-test", api_url="http://localhost:3002" ) + assert result is mock_fc.return_value - def test_client_with_url_only_no_key(self): - """Self-hosted mode: URL without API key should work.""" - self._clear_firecrawl_env() - with patch.dict( - os.environ, - {"FIRECRAWL_API_URL": "http://localhost:3002"}, - clear=False, - ): - with patch("tools.web_tools.Firecrawl") as mock_firecrawl: + def test_self_hosted_no_key(self): + """URL only, no key → self-hosted without auth.""" + with patch.dict(os.environ, {"FIRECRAWL_API_URL": "http://localhost:3002"}): + with patch("tools.web_tools.Firecrawl") as mock_fc: from tools.web_tools import _get_firecrawl_client + result = _get_firecrawl_client() + mock_fc.assert_called_once_with(api_url="http://localhost:3002") + assert result is mock_fc.return_value - _get_firecrawl_client() - mock_firecrawl.assert_called_once_with( - api_url="http://localhost:3002" - ) - - def test_no_key_no_url_raises(self): - """Neither key nor URL set should raise a clear error.""" - self._clear_firecrawl_env() + def test_no_config_raises_with_helpful_message(self): + """Neither key nor URL → ValueError with guidance.""" with patch("tools.web_tools.Firecrawl"): from tools.web_tools import _get_firecrawl_client - with pytest.raises(ValueError, match="FIRECRAWL_API_KEY"): _get_firecrawl_client() + + # ── Singleton caching ──────────────────────────────────────────── + + def test_singleton_returns_same_instance(self): + """Second call returns cached client without re-constructing.""" + with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}): + with patch("tools.web_tools.Firecrawl") as mock_fc: + from tools.web_tools import _get_firecrawl_client + client1 = _get_firecrawl_client() + client2 = _get_firecrawl_client() + assert client1 is client2 + mock_fc.assert_called_once() # constructed only once + + def test_constructor_failure_allows_retry(self): + """If Firecrawl() raises, next call should retry (not return None).""" + import tools.web_tools + with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}): + with patch("tools.web_tools.Firecrawl") as mock_fc: + mock_fc.side_effect = [RuntimeError("init failed"), MagicMock()] + from tools.web_tools import _get_firecrawl_client + + with pytest.raises(RuntimeError): + _get_firecrawl_client() + + # Client stayed None, so retry should work + assert tools.web_tools._firecrawl_client is None + result = _get_firecrawl_client() + assert result is not None + + # ── Edge cases ─────────────────────────────────────────────────── + + def test_empty_string_key_treated_as_absent(self): + """FIRECRAWL_API_KEY='' should not be passed as api_key.""" + with patch.dict(os.environ, { + "FIRECRAWL_API_KEY": "", + "FIRECRAWL_API_URL": "http://localhost:3002", + }): + with patch("tools.web_tools.Firecrawl") as mock_fc: + from tools.web_tools import _get_firecrawl_client + _get_firecrawl_client() + # Empty string is falsy, so only api_url should be passed + mock_fc.assert_called_once_with(api_url="http://localhost:3002") + + def test_empty_string_key_no_url_raises(self): + """FIRECRAWL_API_KEY='' with no URL → should raise.""" + with patch.dict(os.environ, {"FIRECRAWL_API_KEY": ""}): + with patch("tools.web_tools.Firecrawl"): + from tools.web_tools import _get_firecrawl_client + with pytest.raises(ValueError): + _get_firecrawl_client()