Files
the-nexus/tests/test_gemini_harness.py
Claude (Opus 4.6) 4496ff2d80
Some checks failed
Deploy Nexus / deploy (push) Has been cancelled
[claude] Stand up Gemini harness as network worker (#748) (#811)
2026-04-04 01:41:53 +00:00

567 lines
25 KiB
Python

#!/usr/bin/env python3
"""
Gemini Harness Test Suite
Tests for the Gemini 3.1 Pro harness implementing the Hermes/OpenClaw worker pattern.
Usage:
pytest tests/test_gemini_harness.py -v
pytest tests/test_gemini_harness.py -v -k "not live"
RUN_LIVE_TESTS=1 pytest tests/test_gemini_harness.py -v # real API calls
"""
import json
import os
import sys
import time
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from nexus.gemini_harness import (
COST_PER_1M_INPUT,
COST_PER_1M_OUTPUT,
GEMINI_MODEL_PRIMARY,
GEMINI_MODEL_SECONDARY,
GEMINI_MODEL_TERTIARY,
HARNESS_ID,
MODEL_FALLBACK_CHAIN,
ContextCache,
GeminiHarness,
GeminiResponse,
)
# ═══════════════════════════════════════════════════════════════════════════
# FIXTURES
# ═══════════════════════════════════════════════════════════════════════════
@pytest.fixture
def harness():
    """Provide a GeminiHarness constructed with a placeholder API key for unit tests."""
    placeholder_key = "fake-key-for-testing"
    return GeminiHarness(api_key=placeholder_key)
@pytest.fixture
def harness_with_context(harness):
    """Return the ``harness`` fixture after a project-context string has been set on it."""
    project_context = "Timmy is sovereign. Gemini is a worker on the network."
    harness.set_context(project_context)
    return harness
@pytest.fixture
def mock_ok_response():
    """Build a stand-in HTTP response: status 200 with a chat-completion style JSON body."""
    body = {
        "choices": [{"message": {"content": "Hello from Gemini"}}],
        "usage": {"prompt_tokens": 10, "completion_tokens": 5},
    }
    response = MagicMock(status_code=200)
    # Tests hand this object to patch("requests.post", return_value=...).
    response.json.return_value = body
    return response
@pytest.fixture
def mock_error_response():
    """Build a stand-in HTTP response representing a 429 rate-limit failure."""
    return MagicMock(status_code=429, text="Rate limit exceeded")
# ═══════════════════════════════════════════════════════════════════════════
# GeminiResponse DATA CLASS
# ═══════════════════════════════════════════════════════════════════════════
class TestGeminiResponse:
    """Checks on the GeminiResponse record: defaults, dict export, and the error field."""

    def test_default_creation(self):
        """A response built with no arguments exposes empty/zero defaults."""
        blank = GeminiResponse()
        expected_defaults = {
            "text": "",
            "model": "",
            "input_tokens": 0,
            "output_tokens": 0,
            "latency_ms": 0.0,
            "cost_usd": 0.0,
        }
        for attr, default in expected_defaults.items():
            assert getattr(blank, attr) == default
        assert blank.cached is False
        assert blank.error is None
        # A timestamp is expected to be populated even without arguments.
        assert blank.timestamp

    def test_to_dict_includes_all_fields(self):
        """to_dict() echoes every constructor value and carries the remaining fields."""
        supplied = {
            "text": "hi",
            "model": "gemini-2.5-pro-preview-03-25",
            "input_tokens": 10,
            "output_tokens": 5,
            "latency_ms": 120.5,
            "cost_usd": 0.000035,
        }
        exported = GeminiResponse(**supplied).to_dict()
        for key, value in supplied.items():
            assert exported[key] == value
        assert exported["cached"] is False
        assert exported["error"] is None
        assert "timestamp" in exported

    def test_error_response(self):
        """An error-only response keeps the message and leaves the text empty."""
        message = "HTTP 429: Rate limit"
        failed = GeminiResponse(error=message)
        assert failed.error == "HTTP 429: Rate limit"
        assert failed.text == ""
# ═══════════════════════════════════════════════════════════════════════════
# ContextCache
# ═══════════════════════════════════════════════════════════════════════════
class TestContextCache:
    """Checks on ContextCache: TTL validity, hit counting, and id uniqueness."""

    def test_valid_fresh_cache(self):
        """A cache with an hour-long TTL is valid immediately after creation."""
        fresh = ContextCache(content="project context", ttl_seconds=3600.0)
        assert fresh.is_valid()

    def test_expired_cache(self):
        """A cache whose 1 ms TTL has elapsed reports itself as invalid."""
        stale = ContextCache(content="old context", ttl_seconds=0.001)
        # Sleep longer than the TTL so the entry has expired by the time it is checked.
        time.sleep(0.01)
        assert not stale.is_valid()

    def test_hit_count_increments(self):
        """Each touch() call adds one to hit_count, starting from zero."""
        entry = ContextCache(content="ctx")
        assert entry.hit_count == 0
        for _ in range(2):
            entry.touch()
        assert entry.hit_count == 2

    def test_unique_cache_ids(self):
        """Two separately created caches receive different cache_id values."""
        first, second = ContextCache(), ContextCache()
        assert first.cache_id != second.cache_id
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — initialization
# ═══════════════════════════════════════════════════════════════════════════
class TestGeminiHarnessInit:
    """Construction-time behaviour of GeminiHarness: model choice, session id, missing key."""

    def test_default_model(self, harness):
        """Without an explicit model the harness uses the primary model."""
        assert harness.model == GEMINI_MODEL_PRIMARY

    def test_custom_model(self):
        """An explicitly requested model is stored unchanged."""
        h = GeminiHarness(api_key="key", model=GEMINI_MODEL_TERTIARY)
        assert h.model == GEMINI_MODEL_TERTIARY

    def test_session_id_generated(self, harness):
        """A non-empty, 8-character session id is present after construction."""
        assert harness.session_id
        assert len(harness.session_id) == 8

    def test_no_api_key_warning(self, caplog, monkeypatch):
        """Constructing the harness without a key logs a warning naming GOOGLE_API_KEY."""
        import logging

        # NOTE(review): GeminiHarness() is also built with no arguments for live
        # tests, so it presumably reads GOOGLE_API_KEY from the environment.
        # Remove any ambient key so this test really exercises the missing-key path.
        monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
        with caplog.at_level(logging.WARNING, logger="gemini"):
            GeminiHarness(api_key="")
        assert "GOOGLE_API_KEY" in caplog.text

    def test_no_api_key_returns_error_response(self, monkeypatch):
        """generate() on a key-less harness returns an error mentioning GOOGLE_API_KEY."""
        # Same isolation as above: a key leaking in from the developer's shell or CI
        # must not turn this unit test into a real API call.
        monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
        h = GeminiHarness(api_key="")
        resp = h.generate("hello")
        assert resp.error is not None
        assert "GOOGLE_API_KEY" in resp.error
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — context caching
# ═══════════════════════════════════════════════════════════════════════════
class TestContextCaching:
    """Checks on the harness context cache: set/clear, injection into messages, hit counts."""

    @staticmethod
    def _joined_text(messages):
        """Concatenate every string-typed ``content`` value, separated by single spaces."""
        text_parts = (
            msg["content"] for msg in messages if isinstance(msg["content"], str)
        )
        return " ".join(text_parts)

    def test_set_context(self, harness):
        """set_context() makes context_status() report a valid cache of matching length."""
        harness.set_context("Project context here", ttl_seconds=600.0)
        report = harness.context_status()
        assert report["cached"] is True
        assert report["valid"] is True
        assert report["content_length"] == len("Project context here")

    def test_clear_context(self, harness_with_context):
        """clear_context() flips the reported ``cached`` flag to False."""
        harness_with_context.clear_context()
        report = harness_with_context.context_status()
        assert report["cached"] is False

    def test_context_injected_in_messages(self, harness_with_context):
        """With use_cache=True the stored context text appears in the built messages."""
        built = harness_with_context._build_messages("Hello", use_cache=True)
        assert "Timmy is sovereign" in self._joined_text(built)

    def test_context_skipped_when_use_cache_false(self, harness_with_context):
        """With use_cache=False the stored context text is absent from the built messages."""
        built = harness_with_context._build_messages("Hello", use_cache=False)
        assert "Timmy is sovereign" not in self._joined_text(built)

    def test_expired_context_not_injected(self, harness):
        """A context whose TTL has elapsed is not included in the built messages."""
        harness.set_context("expired ctx", ttl_seconds=0.001)
        # Wait past the 1 ms TTL before building messages.
        time.sleep(0.01)
        built = harness._build_messages("Hello", use_cache=True)
        assert "expired ctx" not in self._joined_text(built)

    def test_cache_hit_count_increments(self, harness_with_context):
        """Two cached message builds leave the cache's hit_count at 2."""
        for question in ("q1", "q2"):
            harness_with_context._build_messages(question, use_cache=True)
        assert harness_with_context._context_cache.hit_count == 2

    def test_context_status_no_cache(self, harness):
        """With no context set, context_status() is exactly ``{"cached": False}``."""
        report = harness.context_status()
        assert report == {"cached": False}
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — cost estimation
# ═══════════════════════════════════════════════════════════════════════════
class TestCostEstimation:
    """Checks on _estimate_cost() and on the ordering of the model fallback chain."""

    def test_cost_zero_tokens(self, harness):
        """Zero input and output tokens cost exactly 0.0."""
        assert harness._estimate_cost(GEMINI_MODEL_PRIMARY, 0, 0) == 0.0

    def test_cost_primary_model(self, harness):
        """One million tokens each way costs the sum of the two per-1M rates."""
        million = 1_000_000
        input_rate = COST_PER_1M_INPUT[GEMINI_MODEL_PRIMARY]
        output_rate = COST_PER_1M_OUTPUT[GEMINI_MODEL_PRIMARY]
        actual = harness._estimate_cost(GEMINI_MODEL_PRIMARY, million, million)
        assert abs(actual - (input_rate + output_rate)) < 0.0001

    def test_cost_tertiary_cheaper_than_primary(self, harness):
        """For the same token counts the tertiary model is cheaper than the primary."""
        tokens = 100_000
        primary_cost = harness._estimate_cost(GEMINI_MODEL_PRIMARY, tokens, tokens)
        tertiary_cost = harness._estimate_cost(GEMINI_MODEL_TERTIARY, tokens, tokens)
        assert tertiary_cost < primary_cost

    def test_fallback_chain_order(self):
        """The fallback chain lists primary, secondary, tertiary in that order."""
        expected_order = (
            GEMINI_MODEL_PRIMARY,
            GEMINI_MODEL_SECONDARY,
            GEMINI_MODEL_TERTIARY,
        )
        for position, model_name in enumerate(expected_order):
            assert MODEL_FALLBACK_CHAIN[position] == model_name
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — generate (mocked HTTP)
# ═══════════════════════════════════════════════════════════════════════════
class TestGenerate:
    """generate() exercised against a patched ``requests.post``."""

    @staticmethod
    def _payload_of(post_mock):
        """Return the ``json`` keyword argument of the most recent patched post call."""
        _positional, keyword = post_mock.call_args
        return keyword["json"]

    def test_generate_success(self, harness, mock_ok_response):
        """A 200 reply is parsed into text, token counts, and the primary model name."""
        with patch("requests.post", return_value=mock_ok_response):
            reply = harness.generate("Hello Timmy")
        assert reply.error is None
        assert reply.text == "Hello from Gemini"
        assert reply.input_tokens == 10
        assert reply.output_tokens == 5
        assert reply.model == GEMINI_MODEL_PRIMARY

    def test_generate_uses_fallback_on_error(self, harness, mock_ok_response, mock_error_response):
        """First model fails, second succeeds."""
        attempts = 0

        def fake_post(*args, **kwargs):
            # Only the very first HTTP attempt gets the 429; all later ones succeed.
            nonlocal attempts
            attempts += 1
            return mock_error_response if attempts == 1 else mock_ok_response

        with patch("requests.post", side_effect=fake_post):
            reply = harness.generate("Hello")
        assert reply.error is None
        assert attempts == 2
        assert reply.model == GEMINI_MODEL_SECONDARY

    def test_generate_all_fail_returns_error(self, harness, mock_error_response):
        """When every attempt gets a 429 the result carries an error containing 'failed'."""
        with patch("requests.post", return_value=mock_error_response):
            reply = harness.generate("Hello")
        assert reply.error is not None
        assert "failed" in reply.error.lower()

    def test_generate_updates_session_stats(self, harness, mock_ok_response):
        """Two successful calls accumulate request and token counters on the harness."""
        with patch("requests.post", return_value=mock_ok_response):
            for question in ("q1", "q2"):
                harness.generate(question)
        assert harness.request_count == 2
        assert harness.total_input_tokens == 20
        assert harness.total_output_tokens == 10

    def test_generate_with_system_prompt(self, harness, mock_ok_response):
        """Passing ``system=`` results in a system-role message in the posted payload."""
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate("Hello", system="You are helpful")
        sent = self._payload_of(mock_post)
        seen_roles = [message["role"] for message in sent["messages"]]
        assert "system" in seen_roles

    def test_generate_string_prompt_wrapped(self, harness, mock_ok_response):
        """A plain string prompt is posted as exactly one user message with that content."""
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate("Test prompt")
        sent = self._payload_of(mock_post)
        from_user = [message for message in sent["messages"] if message["role"] == "user"]
        assert len(from_user) == 1
        assert from_user[0]["content"] == "Test prompt"

    def test_generate_list_prompt_passed_through(self, harness, mock_ok_response):
        """A list of role/content dicts is accepted as the prompt without error."""
        conversation = [
            {"role": "user", "content": "first"},
            {"role": "assistant", "content": "reply"},
            {"role": "user", "content": "follow up"},
        ]
        with patch("requests.post", return_value=mock_ok_response):
            reply = harness.generate(conversation)
        assert reply.error is None
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — generate_code
# ═══════════════════════════════════════════════════════════════════════════
class TestGenerateCode:
    """generate_code() exercised against a patched ``requests.post``."""

    @staticmethod
    def _messages_with_role(post_mock, role):
        """Return the posted messages whose ``role`` equals *role*."""
        _positional, keyword = post_mock.call_args
        return [msg for msg in keyword["json"]["messages"] if msg["role"] == role]

    def test_generate_code_success(self, harness, mock_ok_response):
        """A successful reply surfaces the mocked text and no error."""
        with patch("requests.post", return_value=mock_ok_response):
            reply = harness.generate_code("write a hello world", language="python")
        assert reply.error is None
        assert reply.text == "Hello from Gemini"

    def test_generate_code_injects_system(self, harness, mock_ok_response):
        """The requested language appears in at least one system message."""
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_code("fizzbuzz", language="go")
        system_messages = self._messages_with_role(mock_post, "system")
        assert any("go" in msg["content"].lower() for msg in system_messages)

    def test_generate_code_with_context(self, harness, mock_ok_response):
        """Code passed as ``context`` shows up in the first user message."""
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_code("extend this", context="def foo(): pass")
        user_messages = self._messages_with_role(mock_post, "user")
        assert "foo" in user_messages[0]["content"]
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — generate_multimodal
# ═══════════════════════════════════════════════════════════════════════════
class TestGenerateMultimodal:
    """generate_multimodal() exercised against a patched ``requests.post``."""

    @staticmethod
    def _image_parts(post_mock):
        """Return the ``image_url`` parts from the first posted message's content list."""
        _positional, keyword = post_mock.call_args
        first_content = keyword["json"]["messages"][0]["content"]
        return [part for part in first_content if part.get("type") == "image_url"]

    def test_multimodal_text_only(self, harness, mock_ok_response):
        """A prompt with no images completes without error."""
        with patch("requests.post", return_value=mock_ok_response):
            reply = harness.generate_multimodal("Describe this")
        assert reply.error is None

    def test_multimodal_with_base64_image(self, harness, mock_ok_response):
        """A base64 image is posted as a single image_url part holding a data URI."""
        picture = {"type": "base64", "data": "abc123", "mime": "image/jpeg"}
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_multimodal("What is in this image?", images=[picture])
        parts = self._image_parts(mock_post)
        assert len(parts) == 1
        assert "data:image/jpeg;base64,abc123" in parts[0]["image_url"]["url"]

    def test_multimodal_with_url_image(self, harness, mock_ok_response):
        """A URL image is posted with its URL unchanged."""
        picture = {"type": "url", "url": "http://example.com/img.png"}
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_multimodal("What is this?", images=[picture])
        parts = self._image_parts(mock_post)
        assert parts[0]["image_url"]["url"] == "http://example.com/img.png"
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — session stats
# ═══════════════════════════════════════════════════════════════════════════
class TestSessionStats:
    """Checks on the dict returned by _session_stats()."""

    def test_session_stats_initial(self, harness):
        """A fresh harness reports zeroed counters and its own session id."""
        snapshot = harness._session_stats()
        for counter in ("request_count", "total_input_tokens", "total_output_tokens"):
            assert snapshot[counter] == 0
        assert snapshot["total_cost_usd"] == 0.0
        assert snapshot["session_id"] == harness.session_id

    def test_session_stats_after_calls(self, harness, mock_ok_response):
        """Two mocked generate() calls are reflected in request and token totals."""
        with patch("requests.post", return_value=mock_ok_response):
            for prompt in ("a", "b"):
                harness.generate(prompt)
        snapshot = harness._session_stats()
        assert snapshot["request_count"] == 2
        assert snapshot["total_input_tokens"] == 20
        assert snapshot["total_output_tokens"] == 10
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — orchestration registration
# ═══════════════════════════════════════════════════════════════════════════
class TestOrchestrationRegistration:
    """register_in_orchestration() exercised against a patched ``requests.post``."""

    def test_register_success(self, harness):
        """A 201 reply makes registration return True."""
        created = MagicMock(status_code=201)
        with patch("requests.post", return_value=created):
            outcome = harness.register_in_orchestration("http://localhost:8000/api/v1/workers/register")
        assert outcome is True

    def test_register_failure_returns_false(self, harness):
        """A 500 reply makes registration return False."""
        server_error = MagicMock(status_code=500, text="Internal error")
        with patch("requests.post", return_value=server_error):
            outcome = harness.register_in_orchestration("http://localhost:8000/api/v1/workers/register")
        assert outcome is False

    def test_register_connection_error_returns_false(self, harness):
        """An exception raised by the HTTP call makes registration return False."""
        boom = Exception("Connection refused")
        with patch("requests.post", side_effect=boom):
            outcome = harness.register_in_orchestration("http://localhost:9999/register")
        assert outcome is False

    def test_register_payload_contains_capabilities(self, harness):
        """The posted payload names the worker, its capabilities, and a 3-entry fallback chain."""
        accepted = MagicMock(status_code=200)
        with patch("requests.post", return_value=accepted) as mock_post:
            harness.register_in_orchestration("http://localhost/register")
        _positional, keyword = mock_post.call_args
        sent = keyword["json"]
        assert sent["worker_id"] == HARNESS_ID
        for capability in ("text", "multimodal", "streaming", "code"):
            assert capability in sent["capabilities"]
        assert len(sent["fallback_chain"]) == 3
# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — async lifecycle (Hermes WS)
# ═══════════════════════════════════════════════════════════════════════════
class TestAsyncLifecycle:
    """start()/stop() behaviour of the harness without a Hermes WebSocket connection."""

    @pytest.mark.asyncio
    async def test_start_without_hermes(self, harness):
        """Start should succeed even if Hermes is not reachable."""
        # Presumably nothing listens on this port, so the WS connection cannot be made.
        unreachable_url = "ws://localhost:19999/ws"
        harness.hermes_ws_url = unreachable_url
        await harness.start()  # must complete without raising
        assert harness._ws_connected is False

    @pytest.mark.asyncio
    async def test_stop_without_connection(self, harness):
        """Stop should succeed gracefully with no WS connection."""
        # Passing means stop() returned without raising; there is nothing else to assert.
        await harness.stop()
# ═══════════════════════════════════════════════════════════════════════════
# HTTP server smoke test
# ═══════════════════════════════════════════════════════════════════════════
class TestHTTPServer:
    """Smoke tests for the HTTP handler class produced by create_app()."""

    @staticmethod
    def _dispatch_get(harness, path):
        """Run the handler's do_GET() for *path* and return the captured responses.

        The handler is created with ``__new__`` so no real socket is needed, and its
        ``_send_json`` is replaced by a recorder. Each captured entry is a
        ``(status, data)`` tuple in call order.
        """
        from nexus.gemini_harness import create_app

        _, GeminiHandler = create_app(harness)
        # Instantiate handler without a real socket
        handler = GeminiHandler.__new__(GeminiHandler)
        responses = []
        handler._send_json = lambda data, status=200: responses.append((status, data))
        handler.path = path
        handler.do_GET()
        return responses

    def test_create_app_returns_classes(self, harness):
        """create_app() returns a server class and a handler class, neither None."""
        from nexus.gemini_harness import create_app

        HTTPServer, GeminiHandler = create_app(harness)
        assert HTTPServer is not None
        assert GeminiHandler is not None

    def test_health_handler(self, harness):
        """Verify health endpoint handler logic via direct method call."""
        responses = self._dispatch_get(harness, "/health")
        assert len(responses) == 1
        status, data = responses[0]
        assert status == 200
        assert data["status"] == "ok"
        assert data["harness"] == HARNESS_ID

    def test_status_handler(self, harness):
        """GET /status reports a zero request count and the harness model."""
        responses = self._dispatch_get(harness, "/status")
        # Fail with a clear assertion rather than an IndexError if nothing was sent.
        assert responses
        _, data = responses[0]
        assert data["request_count"] == 0
        assert data["model"] == harness.model

    def test_unknown_get_returns_404(self, harness):
        """GET on an unrecognised path is answered with status 404."""
        responses = self._dispatch_get(harness, "/nonexistent")
        # Fail with a clear assertion rather than an IndexError if nothing was sent.
        assert responses
        assert responses[0][0] == 404
# ═══════════════════════════════════════════════════════════════════════════
# Live API tests (skipped unless RUN_LIVE_TESTS=1 and GOOGLE_API_KEY set)
# ═══════════════════════════════════════════════════════════════════════════
def _live_tests_enabled():
    """Return True only when RUN_LIVE_TESTS is "1" and GOOGLE_API_KEY is non-empty."""
    if os.environ.get("RUN_LIVE_TESTS") != "1":
        return False
    return bool(os.environ.get("GOOGLE_API_KEY"))
@pytest.mark.skipif(
    not _live_tests_enabled(),
    reason="Live tests require RUN_LIVE_TESTS=1 and GOOGLE_API_KEY",
)
class TestLiveAPI:
    """Integration tests that hit the real Gemini API."""

    @pytest.fixture
    def live_harness(self):
        """Provide a harness built with default arguments (no explicit API key)."""
        return GeminiHarness()

    def test_live_generate(self, live_harness):
        """A 'pong' prompt yields a reply starting with 'pong' plus usage and latency data."""
        reply = live_harness.generate("Say 'pong' and nothing else.")
        assert reply.error is None
        normalized = reply.text.strip().lower()
        assert normalized.startswith("pong")
        assert reply.input_tokens > 0
        assert reply.latency_ms > 0

    def test_live_generate_code(self, live_harness):
        """Generated code for 'returns 42' contains the literal 42."""
        reply = live_harness.generate_code(
            "write a function that returns 42",
            language="python",
        )
        assert reply.error is None
        assert "42" in reply.text

    def test_live_stream(self, live_harness):
        """Streaming a short prompt yields at least one chunk."""
        stream = live_harness.stream_generate("Count to 3: one, two, three.")
        received = list(stream)
        assert len(received) > 0
if __name__ == "__main__":
    # pytest.main() returns the run's exit code instead of exiting. Hand it to
    # sys.exit so that running this file directly ends with a non-zero status
    # when tests fail, rather than always exiting 0.
    sys.exit(pytest.main([__file__, "-v"]))