test: comprehensive tests for model metadata + firecrawl config
model_metadata tests (61 tests, was 39):
- Token estimation: concrete value assertions, unicode, tool_call messages,
vision multimodal content, additive verification
- Context length resolution: cache-over-API priority, no-base_url skips cache,
missing context_length key in API response
- API metadata fetch: canonical_slug aliasing, TTL expiry with time mock,
stale cache fallback on API failure, malformed JSON resilience
- Probe tiers: above-max returns 2M, zero returns None
- Error parsing: Anthropic format ('X > Y maximum'), LM Studio, empty string,
unreasonably large numbers — also fixed parser to handle Anthropic format
- Cache: corruption resilience (garbage YAML, wrong structure), value updates,
special chars in model names
Firecrawl config tests (8 tests, was 4):
- Singleton caching (core purpose — verified constructor called once)
- Constructor failure recovery (retry after exception)
- Return value actually asserted (not just constructor args)
- Empty string env vars treated as absent
- Proper setup/teardown for env var isolation
This commit is contained in:
@@ -158,6 +158,8 @@ def parse_context_limit_from_error(error_msg: str) -> Optional[int]:
|
||||
r'(?:max(?:imum)?|limit)\s*(?:context\s*)?(?:length|size|window)?\s*(?:is|of|:)?\s*(\d{4,})',
|
||||
r'context\s*(?:length|size|window)\s*(?:is|of|:)?\s*(\d{4,})',
|
||||
r'(\d{4,})\s*(?:token)?\s*(?:context|limit)',
|
||||
r'>\s*(\d{4,})\s*(?:max|limit|token)', # "250000 tokens > 200000 maximum"
|
||||
r'(\d{4,})\s*(?:max(?:imum)?)\b', # "200000 maximum"
|
||||
]
|
||||
for pattern in patterns:
|
||||
match = re.search(pattern, error_lower)
|
||||
|
||||
@@ -1,10 +1,22 @@
|
||||
"""Tests for agent/model_metadata.py — token estimation and context lengths."""
|
||||
"""Tests for agent/model_metadata.py — token estimation, context lengths,
|
||||
probing, caching, and error parsing.
|
||||
|
||||
Coverage levels:
|
||||
Token estimation — concrete value assertions, edge cases
|
||||
Context length lookup — resolution order, fuzzy match, cache priority
|
||||
API metadata fetch — caching, TTL, canonical slugs, stale fallback
|
||||
Probe tiers — descending, boundaries, extreme inputs
|
||||
Error parsing — OpenAI, Ollama, Anthropic, edge cases
|
||||
Persistent cache — save/load, corruption, update, provider isolation
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from agent.model_metadata import (
|
||||
@@ -34,12 +46,9 @@ class TestEstimateTokensRough:
|
||||
assert estimate_tokens_rough(None) == 0
|
||||
|
||||
def test_known_length(self):
|
||||
# 400 chars / 4 = 100 tokens
|
||||
text = "a" * 400
|
||||
assert estimate_tokens_rough(text) == 100
|
||||
assert estimate_tokens_rough("a" * 400) == 100
|
||||
|
||||
def test_short_text(self):
|
||||
# "hello" = 5 chars -> 5 // 4 = 1
|
||||
assert estimate_tokens_rough("hello") == 1
|
||||
|
||||
def test_proportional(self):
|
||||
@@ -47,23 +56,48 @@ class TestEstimateTokensRough:
|
||||
long = estimate_tokens_rough("hello world " * 100)
|
||||
assert long > short
|
||||
|
||||
def test_unicode_multibyte(self):
|
||||
"""Unicode chars are still 1 Python char each — 4 chars/token holds."""
|
||||
text = "你好世界" # 4 CJK characters
|
||||
assert estimate_tokens_rough(text) == 1
|
||||
|
||||
|
||||
class TestEstimateMessagesTokensRough:
|
||||
def test_empty_list(self):
|
||||
assert estimate_messages_tokens_rough([]) == 0
|
||||
|
||||
def test_single_message(self):
|
||||
msgs = [{"role": "user", "content": "a" * 400}]
|
||||
result = estimate_messages_tokens_rough(msgs)
|
||||
assert result > 0
|
||||
def test_single_message_concrete_value(self):
|
||||
"""Verify against known str(msg) length."""
|
||||
msg = {"role": "user", "content": "a" * 400}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
expected = len(str(msg)) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_multiple_messages(self):
|
||||
def test_multiple_messages_additive(self):
|
||||
msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there, how can I help?"},
|
||||
]
|
||||
result = estimate_messages_tokens_rough(msgs)
|
||||
expected = sum(len(str(m)) for m in msgs) // 4
|
||||
assert result == expected
|
||||
|
||||
def test_tool_call_message(self):
|
||||
"""Tool call messages with no 'content' key still contribute tokens."""
|
||||
msg = {"role": "assistant", "content": None,
|
||||
"tool_calls": [{"id": "1", "function": {"name": "terminal", "arguments": "{}"}}]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result > 0
|
||||
assert result == len(str(msg)) // 4
|
||||
|
||||
def test_message_with_list_content(self):
|
||||
"""Vision messages with multimodal content arrays."""
|
||||
msg = {"role": "user", "content": [
|
||||
{"type": "text", "text": "describe"},
|
||||
{"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}}
|
||||
]}
|
||||
result = estimate_messages_tokens_rough([msg])
|
||||
assert result == len(str(msg)) // 4
|
||||
|
||||
|
||||
# =========================================================================
|
||||
@@ -90,9 +124,12 @@ class TestDefaultContextLengths:
|
||||
for key, value in DEFAULT_CONTEXT_LENGTHS.items():
|
||||
assert value > 0, f"{key} has non-positive context length"
|
||||
|
||||
def test_dict_is_not_empty(self):
|
||||
assert len(DEFAULT_CONTEXT_LENGTHS) >= 10
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# get_model_context_length (with mocked API)
|
||||
# get_model_context_length — resolution order
|
||||
# =========================================================================
|
||||
|
||||
class TestGetModelContextLength:
|
||||
@@ -105,62 +142,146 @@ class TestGetModelContextLength:
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_fallback_to_defaults(self, mock_fetch):
|
||||
mock_fetch.return_value = {} # API returns nothing
|
||||
result = get_model_context_length("anthropic/claude-sonnet-4")
|
||||
assert result == 200000
|
||||
mock_fetch.return_value = {}
|
||||
assert get_model_context_length("anthropic/claude-sonnet-4") == 200000
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_unknown_model_returns_first_probe_tier(self, mock_fetch):
|
||||
mock_fetch.return_value = {}
|
||||
result = get_model_context_length("unknown/never-heard-of-this")
|
||||
assert result == CONTEXT_PROBE_TIERS[0] # 2M — will be narrowed on context error
|
||||
assert get_model_context_length("unknown/never-heard-of-this") == CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_partial_match_in_defaults(self, mock_fetch):
|
||||
mock_fetch.return_value = {}
|
||||
# "gpt-4o" is a substring match for "openai/gpt-4o"
|
||||
result = get_model_context_length("openai/gpt-4o")
|
||||
assert result == 128000
|
||||
assert get_model_context_length("openai/gpt-4o") == 128000
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_api_missing_context_length_key(self, mock_fetch):
|
||||
"""Model in API but without context_length → defaults to 128000."""
|
||||
mock_fetch.return_value = {"test/model": {"name": "Test"}}
|
||||
assert get_model_context_length("test/model") == 128000
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_cache_takes_priority_over_api(self, mock_fetch, tmp_path):
|
||||
"""Persistent cache should be checked BEFORE API metadata."""
|
||||
mock_fetch.return_value = {"my/model": {"context_length": 999999}}
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("my/model", "http://local", 32768)
|
||||
result = get_model_context_length("my/model", base_url="http://local")
|
||||
assert result == 32768 # cache wins over API's 999999
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_no_base_url_skips_cache(self, mock_fetch, tmp_path):
|
||||
"""Without base_url, cache lookup is skipped."""
|
||||
mock_fetch.return_value = {}
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("custom/model", "http://local", 32768)
|
||||
# No base_url → cache skipped → falls to probe tier
|
||||
result = get_model_context_length("custom/model")
|
||||
assert result == CONTEXT_PROBE_TIERS[0]
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# fetch_model_metadata (cache behavior)
|
||||
# fetch_model_metadata — caching, TTL, slugs, failures
|
||||
# =========================================================================
|
||||
|
||||
class TestFetchModelMetadata:
|
||||
@patch("agent.model_metadata.requests.get")
|
||||
def test_caches_result(self, mock_get):
|
||||
def _reset_cache(self):
|
||||
import agent.model_metadata as mm
|
||||
# Reset cache
|
||||
mm._model_metadata_cache = {}
|
||||
mm._model_metadata_cache_time = 0
|
||||
|
||||
@patch("agent.model_metadata.requests.get")
|
||||
def test_caches_result(self, mock_get):
|
||||
self._reset_cache()
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"data": [
|
||||
{"id": "test/model", "context_length": 99999, "name": "Test Model"}
|
||||
]
|
||||
"data": [{"id": "test/model", "context_length": 99999, "name": "Test"}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
# First call fetches
|
||||
result1 = fetch_model_metadata(force_refresh=True)
|
||||
assert "test/model" in result1
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Second call uses cache
|
||||
result2 = fetch_model_metadata()
|
||||
assert "test/model" in result2
|
||||
assert mock_get.call_count == 1 # Not called again
|
||||
assert mock_get.call_count == 1 # cached
|
||||
|
||||
@patch("agent.model_metadata.requests.get")
|
||||
def test_api_failure_returns_empty(self, mock_get):
|
||||
def test_api_failure_returns_empty_on_cold_cache(self, mock_get):
|
||||
self._reset_cache()
|
||||
mock_get.side_effect = Exception("Network error")
|
||||
result = fetch_model_metadata(force_refresh=True)
|
||||
assert result == {}
|
||||
|
||||
@patch("agent.model_metadata.requests.get")
|
||||
def test_api_failure_returns_stale_cache(self, mock_get):
|
||||
"""On API failure with existing cache, stale data is returned."""
|
||||
import agent.model_metadata as mm
|
||||
mm._model_metadata_cache = {}
|
||||
mm._model_metadata_cache_time = 0
|
||||
mm._model_metadata_cache = {"old/model": {"context_length": 50000}}
|
||||
mm._model_metadata_cache_time = 0 # expired
|
||||
|
||||
mock_get.side_effect = Exception("Network error")
|
||||
result = fetch_model_metadata(force_refresh=True)
|
||||
assert "old/model" in result
|
||||
assert result["old/model"]["context_length"] == 50000
|
||||
|
||||
@patch("agent.model_metadata.requests.get")
|
||||
def test_canonical_slug_aliasing(self, mock_get):
|
||||
"""Models with canonical_slug get indexed under both IDs."""
|
||||
self._reset_cache()
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"data": [{
|
||||
"id": "anthropic/claude-3.5-sonnet:beta",
|
||||
"canonical_slug": "anthropic/claude-3.5-sonnet",
|
||||
"context_length": 200000,
|
||||
"name": "Claude 3.5 Sonnet"
|
||||
}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = fetch_model_metadata(force_refresh=True)
|
||||
# Both the original ID and canonical slug should work
|
||||
assert "anthropic/claude-3.5-sonnet:beta" in result
|
||||
assert "anthropic/claude-3.5-sonnet" in result
|
||||
assert result["anthropic/claude-3.5-sonnet"]["context_length"] == 200000
|
||||
|
||||
@patch("agent.model_metadata.requests.get")
|
||||
def test_ttl_expiry_triggers_refetch(self, mock_get):
|
||||
"""Cache expires after _MODEL_CACHE_TTL seconds."""
|
||||
import agent.model_metadata as mm
|
||||
self._reset_cache()
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"id": "m1", "context_length": 1000, "name": "M1"}]
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
fetch_model_metadata(force_refresh=True)
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
# Simulate TTL expiry
|
||||
mm._model_metadata_cache_time = time.time() - _MODEL_CACHE_TTL - 1
|
||||
fetch_model_metadata()
|
||||
assert mock_get.call_count == 2 # refetched
|
||||
|
||||
@patch("agent.model_metadata.requests.get")
|
||||
def test_malformed_json_no_data_key(self, mock_get):
|
||||
"""API returns JSON without 'data' key — empty cache, no crash."""
|
||||
self._reset_cache()
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"error": "something"}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
result = fetch_model_metadata(force_refresh=True)
|
||||
assert result == {}
|
||||
|
||||
@@ -198,9 +319,15 @@ class TestGetNextProbeTier:
|
||||
assert get_next_probe_tier(16_000) is None
|
||||
|
||||
def test_from_arbitrary_value(self):
|
||||
# 300K is between 512K and 200K, should return 200K
|
||||
assert get_next_probe_tier(300_000) == 200_000
|
||||
|
||||
def test_above_max_tier(self):
|
||||
"""Value above 2M should return 2M."""
|
||||
assert get_next_probe_tier(5_000_000) == 2_000_000
|
||||
|
||||
def test_zero_returns_none(self):
|
||||
assert get_next_probe_tier(0) is None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Error message parsing
|
||||
@@ -220,17 +347,35 @@ class TestParseContextLimitFromError:
|
||||
assert parse_context_limit_from_error(msg) == 65536
|
||||
|
||||
def test_no_limit_in_message(self):
|
||||
msg = "Something went wrong with the API"
|
||||
assert parse_context_limit_from_error(msg) is None
|
||||
assert parse_context_limit_from_error("Something went wrong with the API") is None
|
||||
|
||||
def test_unreasonable_number_rejected(self):
|
||||
msg = "context length is 42 tokens" # too small
|
||||
assert parse_context_limit_from_error(msg) is None
|
||||
def test_unreasonable_small_number_rejected(self):
|
||||
assert parse_context_limit_from_error("context length is 42 tokens") is None
|
||||
|
||||
def test_ollama_format(self):
|
||||
msg = "Context size has been exceeded. Maximum context size is 32768"
|
||||
assert parse_context_limit_from_error(msg) == 32768
|
||||
|
||||
def test_anthropic_format(self):
|
||||
msg = "prompt is too long: 250000 tokens > 200000 maximum"
|
||||
# Should extract 200000 (the limit), not 250000 (the input size)
|
||||
assert parse_context_limit_from_error(msg) == 200000
|
||||
|
||||
def test_lmstudio_format(self):
|
||||
msg = "Error: context window of 4096 tokens exceeded"
|
||||
assert parse_context_limit_from_error(msg) == 4096
|
||||
|
||||
def test_completely_unrelated_error(self):
|
||||
assert parse_context_limit_from_error("Invalid API key") is None
|
||||
|
||||
def test_empty_string(self):
|
||||
assert parse_context_limit_from_error("") is None
|
||||
|
||||
def test_number_outside_reasonable_range(self):
|
||||
"""Very large number (>10M) should be rejected."""
|
||||
msg = "maximum context length is 99999999999"
|
||||
assert parse_context_limit_from_error(msg) is None
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Persistent context length cache
|
||||
@@ -238,11 +383,10 @@ class TestParseContextLimitFromError:
|
||||
|
||||
class TestContextLengthCache:
|
||||
def test_save_and_load(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("test/model", "http://localhost:8080/v1", 32768)
|
||||
result = get_cached_context_length("test/model", "http://localhost:8080/v1")
|
||||
assert result == 32768
|
||||
assert get_cached_context_length("test/model", "http://localhost:8080/v1") == 32768
|
||||
|
||||
def test_missing_cache_returns_none(self, tmp_path):
|
||||
cache_file = tmp_path / "nonexistent.yaml"
|
||||
@@ -250,7 +394,7 @@ class TestContextLengthCache:
|
||||
assert get_cached_context_length("test/model", "http://x") is None
|
||||
|
||||
def test_multiple_models_cached(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("model-a", "http://a", 64000)
|
||||
save_context_length("model-b", "http://b", 128000)
|
||||
@@ -258,7 +402,7 @@ class TestContextLengthCache:
|
||||
assert get_cached_context_length("model-b", "http://b") == 128000
|
||||
|
||||
def test_same_model_different_providers(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("llama-3", "http://local:8080", 32768)
|
||||
save_context_length("llama-3", "https://openrouter.ai/api/v1", 131072)
|
||||
@@ -266,20 +410,49 @@ class TestContextLengthCache:
|
||||
assert get_cached_context_length("llama-3", "https://openrouter.ai/api/v1") == 131072
|
||||
|
||||
def test_idempotent_save(self, tmp_path):
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("model", "http://x", 32768)
|
||||
save_context_length("model", "http://x", 32768) # same value
|
||||
save_context_length("model", "http://x", 32768)
|
||||
with open(cache_file) as f:
|
||||
data = yaml.safe_load(f)
|
||||
assert len(data["context_lengths"]) == 1
|
||||
|
||||
def test_update_existing_value(self, tmp_path):
|
||||
"""Saving a different value for the same key overwrites it."""
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("model", "http://x", 128000)
|
||||
save_context_length("model", "http://x", 64000)
|
||||
assert get_cached_context_length("model", "http://x") == 64000
|
||||
|
||||
def test_corrupted_yaml_returns_empty(self, tmp_path):
|
||||
"""Corrupted cache file is handled gracefully."""
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
cache_file.write_text("{{{{not valid yaml: [[[")
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
assert get_cached_context_length("model", "http://x") is None
|
||||
|
||||
def test_wrong_structure_returns_none(self, tmp_path):
|
||||
"""YAML that loads but has wrong structure."""
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
cache_file.write_text("just_a_string\n")
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
assert get_cached_context_length("model", "http://x") is None
|
||||
|
||||
@patch("agent.model_metadata.fetch_model_metadata")
|
||||
def test_cached_value_takes_priority(self, mock_fetch, tmp_path):
|
||||
"""Cached context length should be used before API or defaults."""
|
||||
mock_fetch.return_value = {}
|
||||
cache_file = tmp_path / "context_length_cache.yaml"
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length("unknown/model", "http://local", 65536)
|
||||
result = get_model_context_length("unknown/model", base_url="http://local")
|
||||
assert result == 65536
|
||||
assert get_model_context_length("unknown/model", base_url="http://local") == 65536
|
||||
|
||||
def test_special_chars_in_model_name(self, tmp_path):
|
||||
"""Model names with colons, slashes, etc. don't break the cache."""
|
||||
cache_file = tmp_path / "cache.yaml"
|
||||
model = "anthropic/claude-3.5-sonnet:beta"
|
||||
url = "https://api.example.com/v1"
|
||||
with patch("agent.model_metadata._get_context_cache_path", return_value=cache_file):
|
||||
save_context_length(model, url, 200000)
|
||||
assert get_cached_context_length(model, url) == 200000
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
"""Tests for Firecrawl client configuration."""
|
||||
"""Tests for Firecrawl client configuration and singleton behavior.
|
||||
|
||||
Coverage:
|
||||
_get_firecrawl_client() — configuration matrix, singleton caching,
|
||||
constructor failure recovery, return value verification, edge cases.
|
||||
"""
|
||||
|
||||
import os
|
||||
import pytest
|
||||
@@ -6,69 +11,109 @@ from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestFirecrawlClientConfig:
|
||||
"""Test suite for Firecrawl client initialization with API URL support."""
|
||||
"""Test suite for Firecrawl client initialization."""
|
||||
|
||||
def teardown_method(self):
|
||||
"""Reset client between tests."""
|
||||
def setup_method(self):
|
||||
"""Reset client and env vars before each test."""
|
||||
import tools.web_tools
|
||||
|
||||
tools.web_tools._firecrawl_client = None
|
||||
|
||||
def _clear_firecrawl_env(self):
|
||||
"""Remove Firecrawl env vars so tests start clean."""
|
||||
for key in ("FIRECRAWL_API_KEY", "FIRECRAWL_API_URL"):
|
||||
os.environ.pop(key, None)
|
||||
|
||||
def test_client_with_api_key_only(self):
|
||||
"""Test client initialization with only API key (cloud mode)."""
|
||||
self._clear_firecrawl_env()
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "test-key"}, clear=False):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_firecrawl:
|
||||
def teardown_method(self):
|
||||
"""Reset client after each test."""
|
||||
import tools.web_tools
|
||||
tools.web_tools._firecrawl_client = None
|
||||
for key in ("FIRECRAWL_API_KEY", "FIRECRAWL_API_URL"):
|
||||
os.environ.pop(key, None)
|
||||
|
||||
# ── Configuration matrix ─────────────────────────────────────────
|
||||
|
||||
def test_cloud_mode_key_only(self):
|
||||
"""API key without URL → cloud Firecrawl."""
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_fc:
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
result = _get_firecrawl_client()
|
||||
mock_fc.assert_called_once_with(api_key="fc-test")
|
||||
assert result is mock_fc.return_value
|
||||
|
||||
_get_firecrawl_client()
|
||||
mock_firecrawl.assert_called_once_with(api_key="test-key")
|
||||
|
||||
def test_client_with_api_key_and_url(self):
|
||||
"""Test client initialization with API key and custom URL."""
|
||||
self._clear_firecrawl_env()
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"FIRECRAWL_API_KEY": "test-key",
|
||||
"FIRECRAWL_API_URL": "http://localhost:3002",
|
||||
},
|
||||
clear=False,
|
||||
):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_firecrawl:
|
||||
def test_self_hosted_with_key(self):
|
||||
"""Both key + URL → self-hosted with auth."""
|
||||
with patch.dict(os.environ, {
|
||||
"FIRECRAWL_API_KEY": "fc-test",
|
||||
"FIRECRAWL_API_URL": "http://localhost:3002",
|
||||
}):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_fc:
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
|
||||
_get_firecrawl_client()
|
||||
mock_firecrawl.assert_called_once_with(
|
||||
api_key="test-key", api_url="http://localhost:3002"
|
||||
result = _get_firecrawl_client()
|
||||
mock_fc.assert_called_once_with(
|
||||
api_key="fc-test", api_url="http://localhost:3002"
|
||||
)
|
||||
assert result is mock_fc.return_value
|
||||
|
||||
def test_client_with_url_only_no_key(self):
|
||||
"""Self-hosted mode: URL without API key should work."""
|
||||
self._clear_firecrawl_env()
|
||||
with patch.dict(
|
||||
os.environ,
|
||||
{"FIRECRAWL_API_URL": "http://localhost:3002"},
|
||||
clear=False,
|
||||
):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_firecrawl:
|
||||
def test_self_hosted_no_key(self):
|
||||
"""URL only, no key → self-hosted without auth."""
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_URL": "http://localhost:3002"}):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_fc:
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
result = _get_firecrawl_client()
|
||||
mock_fc.assert_called_once_with(api_url="http://localhost:3002")
|
||||
assert result is mock_fc.return_value
|
||||
|
||||
_get_firecrawl_client()
|
||||
mock_firecrawl.assert_called_once_with(
|
||||
api_url="http://localhost:3002"
|
||||
)
|
||||
|
||||
def test_no_key_no_url_raises(self):
|
||||
"""Neither key nor URL set should raise a clear error."""
|
||||
self._clear_firecrawl_env()
|
||||
def test_no_config_raises_with_helpful_message(self):
|
||||
"""Neither key nor URL → ValueError with guidance."""
|
||||
with patch("tools.web_tools.Firecrawl"):
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
|
||||
with pytest.raises(ValueError, match="FIRECRAWL_API_KEY"):
|
||||
_get_firecrawl_client()
|
||||
|
||||
# ── Singleton caching ────────────────────────────────────────────
|
||||
|
||||
def test_singleton_returns_same_instance(self):
|
||||
"""Second call returns cached client without re-constructing."""
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_fc:
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
client1 = _get_firecrawl_client()
|
||||
client2 = _get_firecrawl_client()
|
||||
assert client1 is client2
|
||||
mock_fc.assert_called_once() # constructed only once
|
||||
|
||||
def test_constructor_failure_allows_retry(self):
|
||||
"""If Firecrawl() raises, next call should retry (not return None)."""
|
||||
import tools.web_tools
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_KEY": "fc-test"}):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_fc:
|
||||
mock_fc.side_effect = [RuntimeError("init failed"), MagicMock()]
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
_get_firecrawl_client()
|
||||
|
||||
# Client stayed None, so retry should work
|
||||
assert tools.web_tools._firecrawl_client is None
|
||||
result = _get_firecrawl_client()
|
||||
assert result is not None
|
||||
|
||||
# ── Edge cases ───────────────────────────────────────────────────
|
||||
|
||||
def test_empty_string_key_treated_as_absent(self):
|
||||
"""FIRECRAWL_API_KEY='' should not be passed as api_key."""
|
||||
with patch.dict(os.environ, {
|
||||
"FIRECRAWL_API_KEY": "",
|
||||
"FIRECRAWL_API_URL": "http://localhost:3002",
|
||||
}):
|
||||
with patch("tools.web_tools.Firecrawl") as mock_fc:
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
_get_firecrawl_client()
|
||||
# Empty string is falsy, so only api_url should be passed
|
||||
mock_fc.assert_called_once_with(api_url="http://localhost:3002")
|
||||
|
||||
def test_empty_string_key_no_url_raises(self):
|
||||
"""FIRECRAWL_API_KEY='' with no URL → should raise."""
|
||||
with patch.dict(os.environ, {"FIRECRAWL_API_KEY": ""}):
|
||||
with patch("tools.web_tools.Firecrawl"):
|
||||
from tools.web_tools import _get_firecrawl_client
|
||||
with pytest.raises(ValueError):
|
||||
_get_firecrawl_client()
|
||||
|
||||
Reference in New Issue
Block a user