Add comprehensive test coverage for: - cron/jobs.py: schedule parsing, job CRUD, due-job detection (34 tests) - tools/memory_tool.py: security scanning, MemoryStore ops, dispatcher (32 tests) - toolsets.py: resolution, validation, composition, cycle detection (19 tests) - tools/file_operations.py: write deny list, result dataclasses, helpers (37 tests) - agent/prompt_builder.py: context scanning, truncation, skills index (24 tests) - agent/model_metadata.py: token estimation, context lengths (16 tests) - hermes_state.py: SessionDB SQLite CRUD, FTS5 search, export, prune (28 tests) Total: 210 new tests, all passing (380 total suite).
157 lines
5.4 KiB
Python
157 lines
5.4 KiB
Python
"""Tests for agent/model_metadata.py — token estimation and context lengths."""
|
|
|
|
import pytest
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
from agent.model_metadata import (
|
|
DEFAULT_CONTEXT_LENGTHS,
|
|
estimate_tokens_rough,
|
|
estimate_messages_tokens_rough,
|
|
get_model_context_length,
|
|
fetch_model_metadata,
|
|
_MODEL_CACHE_TTL,
|
|
)
|
|
|
|
|
|
# =========================================================================
|
|
# Token estimation
|
|
# =========================================================================
|
|
|
|
class TestEstimateTokensRough:
    """Checks for ``estimate_tokens_rough`` on single strings."""

    def test_empty_string(self):
        """An empty string is estimated at zero tokens."""
        token_count = estimate_tokens_rough("")
        assert token_count == 0

    def test_none_returns_zero(self):
        """``None`` is accepted as input and estimated at zero tokens."""
        token_count = estimate_tokens_rough(None)
        assert token_count == 0

    def test_known_length(self):
        """400 characters are expected to yield 100 tokens (400 / 4)."""
        sample = "a" * 400
        token_count = estimate_tokens_rough(sample)
        assert token_count == 100

    def test_short_text(self):
        """Five characters are expected to yield one token (5 // 4 = 1)."""
        token_count = estimate_tokens_rough("hello")
        assert token_count == 1

    def test_proportional(self):
        """A phrase repeated 100 times must score higher than the phrase once."""
        brief_estimate = estimate_tokens_rough("hello world")
        lengthy_estimate = estimate_tokens_rough("hello world " * 100)
        assert brief_estimate < lengthy_estimate
|
|
|
|
|
|
class TestEstimateMessagesTokensRough:
    """Checks for ``estimate_messages_tokens_rough`` on lists of message dicts."""

    def test_empty_list(self):
        """No messages means an estimate of zero."""
        no_messages = []
        assert estimate_messages_tokens_rough(no_messages) == 0

    def test_single_message(self):
        """One message with 400 characters of content gives a positive estimate."""
        lone_message = {"role": "user", "content": "a" * 400}
        estimate = estimate_messages_tokens_rough([lone_message])
        assert estimate > 0

    def test_multiple_messages(self):
        """A user/assistant exchange gives a positive estimate."""
        user_turn = {"role": "user", "content": "Hello"}
        assistant_turn = {
            "role": "assistant",
            "content": "Hi there, how can I help?",
        }
        estimate = estimate_messages_tokens_rough([user_turn, assistant_turn])
        assert estimate > 0
|
|
|
|
|
|
# =========================================================================
|
|
# Default context lengths
|
|
# =========================================================================
|
|
|
|
class TestDefaultContextLengths:
    """Sanity checks on the values stored in ``DEFAULT_CONTEXT_LENGTHS``.

    The family checks only look at keys containing a given substring, so a
    family with no matching key passes trivially.
    """

    @staticmethod
    def _entries_matching(fragment):
        """Return the ``(key, value)`` pairs whose key contains *fragment*."""
        return [
            (model_key, limit)
            for model_key, limit in DEFAULT_CONTEXT_LENGTHS.items()
            if fragment in model_key
        ]

    def test_claude_models_200k(self):
        """Every key containing "claude" maps to 200000."""
        for model_key, limit in self._entries_matching("claude"):
            assert limit == 200000, f"{model_key} should be 200000"

    def test_gpt4_models_128k(self):
        """Every key containing "gpt-4" maps to 128000."""
        for model_key, limit in self._entries_matching("gpt-4"):
            assert limit == 128000, f"{model_key} should be 128000"

    def test_gemini_models_1m(self):
        """Every key containing "gemini" maps to 1048576."""
        for model_key, limit in self._entries_matching("gemini"):
            assert limit == 1048576, f"{model_key} should be 1048576"

    def test_all_values_positive(self):
        """No entry may carry a zero or negative context length."""
        for model_key, limit in DEFAULT_CONTEXT_LENGTHS.items():
            assert limit > 0, f"{model_key} has non-positive context length"
|
|
|
|
|
|
# =========================================================================
|
|
# get_model_context_length (with mocked API)
|
|
# =========================================================================
|
|
|
|
class TestGetModelContextLength:
    """Checks for ``get_model_context_length`` with ``fetch_model_metadata`` mocked."""

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_known_model_from_api(self, mock_fetch):
        """A context length present in the fetched metadata is returned."""
        api_metadata = {"test/model": {"context_length": 32000}}
        mock_fetch.return_value = api_metadata
        context_length = get_model_context_length("test/model")
        assert context_length == 32000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_fallback_to_defaults(self, mock_fetch):
        """With empty fetched metadata, a Claude model resolves to 200000."""
        mock_fetch.return_value = {}  # API returns nothing
        context_length = get_model_context_length("anthropic/claude-sonnet-4")
        assert context_length == 200000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_unknown_model_returns_128k(self, mock_fetch):
        """With empty fetched metadata, an unrecognised model resolves to 128000."""
        mock_fetch.return_value = {}
        context_length = get_model_context_length("unknown/never-heard-of-this")
        assert context_length == 128000

    @patch("agent.model_metadata.fetch_model_metadata")
    def test_partial_match_in_defaults(self, mock_fetch):
        """A provider-prefixed id still resolves when fetched metadata is empty."""
        mock_fetch.return_value = {}
        # "gpt-4o" is a substring match for "openai/gpt-4o"
        context_length = get_model_context_length("openai/gpt-4o")
        assert context_length == 128000
|
|
|
|
|
|
# =========================================================================
|
|
# fetch_model_metadata (cache behavior)
|
|
# =========================================================================
|
|
|
|
class TestFetchModelMetadata:
    """Checks for the caching and failure handling of ``fetch_model_metadata``.

    Each test swaps the module-level cache variables for an empty cache only
    for its own duration. The previous values are put back afterwards, so the
    fake entries written here cannot leak into other tests in the same process.
    """

    @staticmethod
    def _fresh_cache():
        """Return a context manager that installs an empty, zero-timestamp cache.

        On exit the module's previous ``_model_metadata_cache`` and
        ``_model_metadata_cache_time`` values are restored.
        """
        import agent.model_metadata as mm

        # create=True keeps the leniency of the plain attribute assignment
        # used previously, which did not require the attributes to pre-exist.
        return patch.multiple(
            mm,
            _model_metadata_cache={},
            _model_metadata_cache_time=0,
            create=True,
        )

    @patch("agent.model_metadata.requests.get")
    def test_caches_result(self, mock_get):
        """A second call without ``force_refresh`` must not call ``requests.get`` again."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "data": [
                {"id": "test/model", "context_length": 99999, "name": "Test Model"}
            ]
        }
        mock_response.raise_for_status = MagicMock()
        mock_get.return_value = mock_response

        with self._fresh_cache():
            # First call fetches
            result1 = fetch_model_metadata(force_refresh=True)
            assert "test/model" in result1
            assert mock_get.call_count == 1

            # Second call uses cache
            result2 = fetch_model_metadata()
            assert "test/model" in result2
            assert mock_get.call_count == 1  # Not called again

    @patch("agent.model_metadata.requests.get")
    def test_api_failure_returns_empty(self, mock_get):
        """An exception raised by ``requests.get`` results in an empty dict."""
        mock_get.side_effect = Exception("Network error")

        with self._fresh_cache():
            result = fetch_model_metadata(force_refresh=True)
            assert result == {}
|