Add unit tests for timmy/research_tools.py

Covers google_web_search (missing API key, correct params, return value, empty query) and get_llm_client (client instantiation, completion method, text content, independence). Stubs serpapi before import so the tests run without the optional package installed.

Fixes #1294

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
125 lines · 4.9 KiB · Python
"""Unit tests for timmy/research_tools.py."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import sys
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
# The module under test imports serpapi, an optional dependency that is not
# installed in the test environment. Register a MagicMock in its place before
# that import happens; setdefault leaves any already-imported module untouched.
sys.modules.setdefault("serpapi", MagicMock())
|
|
|
|
from timmy.research_tools import get_llm_client, google_web_search # noqa: E402
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# google_web_search
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGoogleWebSearch:
    """Tests for google_web_search().

    Every test patches ``timmy.research_tools.GoogleSearch``, so no request can
    reach a real search backend even when the serpapi package is installed.
    """

    @staticmethod
    def _search_cls_returning(payload: dict) -> MagicMock:
        """Build a mock GoogleSearch class whose instances' get_dict() returns *payload*."""
        search_instance = MagicMock()
        search_instance.get_dict.return_value = payload
        return MagicMock(return_value=search_instance)

    @pytest.mark.asyncio
    async def test_missing_api_key_returns_empty_string(self):
        """Returns '' without constructing GoogleSearch when SERPAPI_API_KEY is absent."""
        mock_search_cls = self._search_cls_returning({})

        # patch.dict with no new values snapshots os.environ and restores it on
        # exit, so deleting the key here cannot leak into other tests.
        with patch.dict(os.environ), patch("timmy.research_tools.GoogleSearch", mock_search_cls):
            os.environ.pop("SERPAPI_API_KEY", None)
            result = await google_web_search("python tutorial")

        assert result == ""
        # The missing-key path must short-circuit before any search is built.
        mock_search_cls.assert_not_called()

    @pytest.mark.asyncio
    async def test_calls_google_search_with_correct_params(self):
        """GoogleSearch is constructed with query and api_key from environ."""
        mock_search_cls = self._search_cls_returning({"organic_results": [{"title": "Hello"}]})

        with patch.dict(os.environ, {"SERPAPI_API_KEY": "test-key-123"}), patch(
            "timmy.research_tools.GoogleSearch", mock_search_cls
        ):
            result = await google_web_search("python tutorial")

        mock_search_cls.assert_called_once_with(
            {"q": "python tutorial", "api_key": "test-key-123"}
        )
        assert "Hello" in result

    @pytest.mark.asyncio
    async def test_returns_stringified_results(self):
        """Return value is str() of whatever get_dict() returns."""
        fake_dict = {"organic_results": [{"title": "Foo", "link": "https://example.com"}]}
        mock_search_cls = self._search_cls_returning(fake_dict)

        with patch.dict(os.environ, {"SERPAPI_API_KEY": "key"}), patch(
            "timmy.research_tools.GoogleSearch", mock_search_cls
        ):
            result = await google_web_search("foo")

        assert result == str(fake_dict)

    @pytest.mark.asyncio
    async def test_empty_query_still_calls_search(self):
        """An empty query is forwarded to GoogleSearch without short-circuiting."""
        mock_search_cls = self._search_cls_returning({})

        with patch.dict(os.environ, {"SERPAPI_API_KEY": "key"}), patch(
            "timmy.research_tools.GoogleSearch", mock_search_cls
        ):
            result = await google_web_search("")

        # Pin the exact params: a bare assert_called_once() would also pass if
        # the query were dropped or replaced before reaching GoogleSearch.
        mock_search_cls.assert_called_once_with({"q": "", "api_key": "key"})
        assert result == str({})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# get_llm_client
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestGetLlmClient:
    """Contract checks for the client object produced by get_llm_client()."""

    def test_returns_a_client_object(self):
        """A client instance (never None) comes back from the factory."""
        assert get_llm_client() is not None

    def test_client_has_completion_method(self):
        """The client carries a `completion` attribute that can be called."""
        completion_attr = getattr(get_llm_client(), "completion", None)
        assert callable(completion_attr)

    @pytest.mark.asyncio
    async def test_completion_returns_object_with_text(self):
        """Awaiting completion() yields a response whose `.text` is a non-empty str."""
        llm = get_llm_client()
        response = await llm.completion("What is Python?", max_tokens=100)

        assert hasattr(response, "text")
        body = response.text
        assert isinstance(body, str)
        assert len(body) > 0

    @pytest.mark.asyncio
    async def test_completion_text_contains_prompt(self):
        """The prompt text appears verbatim inside the stub's reply."""
        question = "Tell me about asyncio"
        llm = get_llm_client()

        response = await llm.completion(question, max_tokens=50)

        assert question in response.text

    @pytest.mark.asyncio
    async def test_multiple_calls_return_independent_objects(self):
        """Two completion() calls produce two distinct response objects."""
        llm = get_llm_client()
        first = await llm.completion("prompt one", max_tokens=10)
        second = await llm.completion("prompt two", max_tokens=10)

        assert first is not second
        assert first.text != second.text

    def test_multiple_calls_return_independent_clients(self):
        """Calling the factory twice hands back two separate clients."""
        first_client, second_client = get_llm_client(), get_llm_client()
        assert first_client is not second_client
|