diff --git a/tests/timmy/test_research_tools.py b/tests/timmy/test_research_tools.py new file mode 100644 index 00000000..057b60bd --- /dev/null +++ b/tests/timmy/test_research_tools.py @@ -0,0 +1,124 @@ +"""Unit tests for timmy/research_tools.py.""" + +from __future__ import annotations + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +# serpapi is an optional dependency not installed in the test environment. +# Stub it before importing the module under test. +if "serpapi" not in sys.modules: + sys.modules["serpapi"] = MagicMock() + +from timmy.research_tools import get_llm_client, google_web_search # noqa: E402 + + +# --------------------------------------------------------------------------- +# google_web_search +# --------------------------------------------------------------------------- + + +class TestGoogleWebSearch: + @pytest.mark.asyncio + async def test_missing_api_key_returns_empty_string(self): + """Returns '' and logs a warning when SERPAPI_API_KEY is absent.""" + env = {k: v for k, v in os.environ.items() if k != "SERPAPI_API_KEY"} + with patch.dict(os.environ, env, clear=True): + result = await google_web_search("python tutorial") + assert result == "" + + @pytest.mark.asyncio + async def test_calls_google_search_with_correct_params(self): + """GoogleSearch is constructed with query and api_key from environ.""" + mock_search_instance = MagicMock() + mock_search_instance.get_dict.return_value = {"organic_results": [{"title": "Hello"}]} + mock_search_cls = MagicMock(return_value=mock_search_instance) + + with patch.dict(os.environ, {"SERPAPI_API_KEY": "test-key-123"}): + with patch("timmy.research_tools.GoogleSearch", mock_search_cls): + result = await google_web_search("python tutorial") + + mock_search_cls.assert_called_once_with( + {"q": "python tutorial", "api_key": "test-key-123"} + ) + assert "Hello" in result + + @pytest.mark.asyncio + async def test_returns_stringified_results(self): + """Return value is str() of whatever get_dict() returns.""" + fake_dict = {"organic_results": [{"title": "Foo", "link": "https://example.com"}]} + mock_search_instance = MagicMock() + mock_search_instance.get_dict.return_value = fake_dict + mock_search_cls = MagicMock(return_value=mock_search_instance) + + with patch.dict(os.environ, {"SERPAPI_API_KEY": "key"}): + with patch("timmy.research_tools.GoogleSearch", mock_search_cls): + result = await google_web_search("foo") + + assert result == str(fake_dict) + + @pytest.mark.asyncio + async def test_empty_query_still_calls_search(self): + """An empty query is forwarded to GoogleSearch without short-circuiting.""" + mock_search_instance = MagicMock() + mock_search_instance.get_dict.return_value = {} + mock_search_cls = MagicMock(return_value=mock_search_instance) + + with patch.dict(os.environ, {"SERPAPI_API_KEY": "key"}): + with patch("timmy.research_tools.GoogleSearch", mock_search_cls): + result = await google_web_search("") + + mock_search_cls.assert_called_once() + assert result == str({}) + + +# --------------------------------------------------------------------------- +# get_llm_client +# --------------------------------------------------------------------------- + + +class TestGetLlmClient: + def test_returns_a_client_object(self): + """get_llm_client() returns a non-None object.""" + client = get_llm_client() + assert client is not None + + def test_client_has_completion_method(self): + """The returned client exposes a callable completion attribute.""" + client = get_llm_client() + assert callable(getattr(client, "completion", None)) + + @pytest.mark.asyncio + async def test_completion_returns_object_with_text(self): + """completion() returns an object whose .text is a non-empty string.""" + client = get_llm_client() + result = await client.completion("What is Python?", max_tokens=100) + assert hasattr(result, "text") + assert isinstance(result.text, str) + assert len(result.text) > 0 + + @pytest.mark.asyncio + async def test_completion_text_contains_prompt(self): + """The stub weaves the prompt into the returned text.""" + client = get_llm_client() + prompt = "Tell me about asyncio" + result = await client.completion(prompt, max_tokens=50) + assert prompt in result.text + + @pytest.mark.asyncio + async def test_multiple_calls_return_independent_objects(self): + """Each call to completion() returns a fresh object.""" + client = get_llm_client() + r1 = await client.completion("prompt one", max_tokens=10) + r2 = await client.completion("prompt two", max_tokens=10) + assert r1 is not r2 + assert r1.text != r2.text + + def test_multiple_calls_return_independent_clients(self): + """Each call to get_llm_client() returns a distinct instance.""" + c1 = get_llm_client() + c2 = get_llm_client() + assert c1 is not c2