From 4496ff2d80542325f0fef05178a9ad5e7b777f1f Mon Sep 17 00:00:00 2001 From: "Claude (Opus 4.6)" Date: Sat, 4 Apr 2026 01:41:53 +0000 Subject: [PATCH] [claude] Stand up Gemini harness as network worker (#748) (#811) --- nexus/gemini_harness.py | 896 +++++++++++++++++++++++++++++++++++ tests/test_gemini_harness.py | 566 ++++++++++++++++++++++ 2 files changed, 1462 insertions(+) create mode 100644 nexus/gemini_harness.py create mode 100644 tests/test_gemini_harness.py diff --git a/nexus/gemini_harness.py b/nexus/gemini_harness.py new file mode 100644 index 0000000..666c6de --- /dev/null +++ b/nexus/gemini_harness.py @@ -0,0 +1,896 @@ +#!/usr/bin/env python3 +""" +Gemini Harness — Hermes/OpenClaw harness backed by Gemini 3.1 Pro + +A harness instance on Timmy's sovereign network, same pattern as Ezra, +Bezalel, and Allegro. Timmy is sovereign; Gemini is a worker. + +Architecture: + Timmy (sovereign) + ├── Ezra (harness) + ├── Bezalel (harness) + ├── Allegro (harness) + └── Gemini (harness — this module) + +Features: +- Text generation, multimodal (image/video), code generation +- Streaming responses +- Context caching for project context +- Model fallback: 3.1 Pro → 3 Pro → Flash +- Latency, token, and cost telemetry +- Hermes WebSocket registration +- HTTP endpoint for network access + +Usage: + # As a standalone harness server: + python -m nexus.gemini_harness --serve + + # Or imported: + from nexus.gemini_harness import GeminiHarness + harness = GeminiHarness() + response = harness.generate("Hello Timmy") + print(response.text) + +Environment Variables: + GOOGLE_API_KEY — Gemini API key (from aistudio.google.com) + HERMES_WS_URL — Hermes WebSocket URL (default: ws://localhost:8000/ws) + GEMINI_MODEL — Override default model +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import os +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, 
AsyncIterator, Iterator, Optional, Union + +import requests + +log = logging.getLogger("gemini") +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [gemini] %(message)s", + datefmt="%H:%M:%S", +) + +# ═══════════════════════════════════════════════════════════════════════════ +# MODEL CONFIGURATION +# ═══════════════════════════════════════════════════════════════════════════ + +# Model fallback chain: primary → secondary → tertiary +GEMINI_MODEL_PRIMARY = "gemini-2.5-pro-preview-03-25" +GEMINI_MODEL_SECONDARY = "gemini-2.0-pro" +GEMINI_MODEL_TERTIARY = "gemini-2.0-flash" +MODEL_FALLBACK_CHAIN = [ + GEMINI_MODEL_PRIMARY, + GEMINI_MODEL_SECONDARY, + GEMINI_MODEL_TERTIARY, +] + +# Gemini API (OpenAI-compatible endpoint for drop-in compatibility) +GEMINI_OPENAI_COMPAT_BASE = ( + "https://generativelanguage.googleapis.com/v1beta/openai" +) +GEMINI_NATIVE_BASE = "https://generativelanguage.googleapis.com/v1beta" + +# Approximate cost per 1M tokens (USD) — used for cost logging only +# Prices current as of April 2026; verify at ai.google.dev/gemini-api/docs/pricing +COST_PER_1M_INPUT = { + GEMINI_MODEL_PRIMARY: 3.50, + GEMINI_MODEL_SECONDARY: 2.00, + GEMINI_MODEL_TERTIARY: 0.10, +} +COST_PER_1M_OUTPUT = { + GEMINI_MODEL_PRIMARY: 10.50, + GEMINI_MODEL_SECONDARY: 8.00, + GEMINI_MODEL_TERTIARY: 0.40, +} + +DEFAULT_HERMES_WS_URL = os.environ.get("HERMES_WS_URL", "ws://localhost:8000/ws") +HARNESS_ID = "gemini" +HARNESS_NAME = "Gemini Harness" + + +# ═══════════════════════════════════════════════════════════════════════════ +# DATA CLASSES +# ═══════════════════════════════════════════════════════════════════════════ + +@dataclass +class GeminiResponse: + """Response from a Gemini generate call.""" + text: str = "" + model: str = "" + input_tokens: int = 0 + output_tokens: int = 0 + latency_ms: float = 0.0 + cost_usd: float = 0.0 + cached: bool = False + error: Optional[str] = None + timestamp: str = field( + default_factory=lambda: 
            datetime.now(timezone.utc).isoformat()
    )

    def to_dict(self) -> dict:
        """Serialize the response to a plain JSON-friendly dict."""
        return {
            "text": self.text,
            "model": self.model,
            "input_tokens": self.input_tokens,
            "output_tokens": self.output_tokens,
            "latency_ms": self.latency_ms,
            "cost_usd": self.cost_usd,
            "cached": self.cached,
            "error": self.error,
            "timestamp": self.timestamp,
        }


@dataclass
class ContextCache:
    """In-memory context cache for project context."""
    # Short random id used only for logging / status reporting.
    cache_id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    content: str = ""
    # Creation wall-clock time; age is measured against this in is_valid().
    created_at: float = field(default_factory=time.time)
    # Number of times this entry has been injected into a request.
    hit_count: int = 0
    ttl_seconds: float = 3600.0  # 1 hour default

    def is_valid(self) -> bool:
        """Return True while the entry is younger than its TTL."""
        return (time.time() - self.created_at) < self.ttl_seconds

    def touch(self) -> None:
        """Record one cache hit (bookkeeping only — does not refresh TTL)."""
        self.hit_count += 1


# ═══════════════════════════════════════════════════════════════════════════
# GEMINI HARNESS
# ═══════════════════════════════════════════════════════════════════════════

class GeminiHarness:
    """
    Gemini harness for Timmy's sovereign network.

    Acts as a Hermes/OpenClaw harness worker backed by the Gemini API.
    Registers itself on the network at startup; accepts text, code, and
    multimodal generation requests.

    All calls flow through the fallback chain (3.1 Pro → 3 Pro → Flash)
    and emit latency/token/cost telemetry to Hermes. 
+ """ + + def __init__( + self, + api_key: Optional[str] = None, + model: Optional[str] = None, + hermes_ws_url: str = DEFAULT_HERMES_WS_URL, + context_ttl: float = 3600.0, + ): + self.api_key = api_key or os.environ.get("GOOGLE_API_KEY", "") + self.model = model or os.environ.get("GEMINI_MODEL", GEMINI_MODEL_PRIMARY) + self.hermes_ws_url = hermes_ws_url + self.context_ttl = context_ttl + + # Context cache (project context stored here to avoid re-sending) + self._context_cache: Optional[ContextCache] = None + + # Session bookkeeping + self.session_id = str(uuid.uuid4())[:8] + self.request_count = 0 + self.total_input_tokens = 0 + self.total_output_tokens = 0 + self.total_cost_usd = 0.0 + + # WebSocket connection (lazy — created on first telemetry send) + self._ws = None + self._ws_connected = False + + if not self.api_key: + log.warning( + "GOOGLE_API_KEY not set — calls will fail. " + "Set it via environment variable or pass api_key=." + ) + + # ═══ LIFECYCLE ═══════════════════════════════════════════════════════ + + async def start(self): + """Register harness on the network via Hermes WebSocket.""" + log.info("=" * 50) + log.info(f"{HARNESS_NAME} — STARTING") + log.info(f" Session: {self.session_id}") + log.info(f" Model: {self.model}") + log.info(f" Hermes: {self.hermes_ws_url}") + log.info("=" * 50) + + await self._connect_hermes() + await self._send_telemetry({ + "type": "harness_register", + "harness_id": HARNESS_ID, + "session_id": self.session_id, + "model": self.model, + "fallback_chain": MODEL_FALLBACK_CHAIN, + "capabilities": ["text", "code", "multimodal", "streaming"], + }) + log.info("Harness registered on network") + + async def stop(self): + """Deregister and disconnect.""" + await self._send_telemetry({ + "type": "harness_deregister", + "harness_id": HARNESS_ID, + "session_id": self.session_id, + "stats": self._session_stats(), + }) + if self._ws: + try: + await self._ws.close() + except Exception: + pass + self._ws_connected = False + 
log.info(f"{HARNESS_NAME} stopped. {self._session_stats()}") + + # ═══ CORE GENERATION ═════════════════════════════════════════════════ + + def generate( + self, + prompt: Union[str, list[dict]], + *, + system: Optional[str] = None, + use_cache: bool = True, + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + ) -> GeminiResponse: + """ + Generate a response from Gemini. + + Tries the model fallback chain: primary → secondary → tertiary. + Injects cached context if available and use_cache=True. + + Args: + prompt: String prompt or list of message dicts + (OpenAI-style: [{"role": "user", "content": "..."}]) + system: Optional system instruction + use_cache: Prepend cached project context if set + stream: Return streaming response (prints to stdout) + max_tokens: Override default max output tokens + temperature: Sampling temperature (0.0–2.0) + + Returns: + GeminiResponse with text, token counts, latency, cost + """ + if not self.api_key: + return GeminiResponse(error="GOOGLE_API_KEY not set") + + messages = self._build_messages(prompt, system=system, use_cache=use_cache) + + for model in MODEL_FALLBACK_CHAIN: + response = self._call_api( + model=model, + messages=messages, + stream=stream, + max_tokens=max_tokens, + temperature=temperature, + ) + if response.error is None: + self._record(response) + return response + log.warning(f"Model {model} failed: {response.error} — trying next") + + # All models failed + final = GeminiResponse(error="All models in fallback chain failed") + self._record(final) + return final + + def generate_code( + self, + task: str, + language: str = "python", + context: Optional[str] = None, + ) -> GeminiResponse: + """ + Specialized code generation call. + + Args: + task: Natural language description of what to code + language: Target programming language + context: Optional code context (existing code, interfaces, etc.) + """ + system = ( + f"You are an expert {language} programmer. 
" + "Produce clean, well-structured code. " + "Return only the code block, no explanation unless asked." + ) + if context: + prompt = f"Context:\n```{language}\n{context}\n```\n\nTask: {task}" + else: + prompt = f"Task: {task}" + + return self.generate(prompt, system=system) + + def generate_multimodal( + self, + text: str, + images: Optional[list[dict]] = None, + system: Optional[str] = None, + ) -> GeminiResponse: + """ + Multimodal generation with text + images. + + Args: + text: Text prompt + images: List of image dicts: [{"type": "base64", "data": "...", "mime": "image/png"}] + or [{"type": "url", "url": "..."}] + system: Optional system instruction + """ + # Build content parts + parts: list[dict] = [{"type": "text", "text": text}] + + if images: + for img in images: + if img.get("type") == "base64": + parts.append({ + "type": "image_url", + "image_url": { + "url": f"data:{img.get('mime', 'image/png')};base64,{img['data']}" + }, + }) + elif img.get("type") == "url": + parts.append({ + "type": "image_url", + "image_url": {"url": img["url"]}, + }) + + messages = [{"role": "user", "content": parts}] + if system: + messages = [{"role": "system", "content": system}] + messages + + for model in MODEL_FALLBACK_CHAIN: + response = self._call_api(model=model, messages=messages) + if response.error is None: + self._record(response) + return response + log.warning(f"Multimodal: model {model} failed: {response.error}") + + return GeminiResponse(error="All models failed for multimodal request") + + def stream_generate( + self, + prompt: Union[str, list[dict]], + system: Optional[str] = None, + use_cache: bool = True, + ) -> Iterator[str]: + """ + Stream text chunks from Gemini. + + Yields string chunks as they arrive. Logs final telemetry when done. 
+ + Usage: + for chunk in harness.stream_generate("Tell me about Timmy"): + print(chunk, end="", flush=True) + """ + messages = self._build_messages(prompt, system=system, use_cache=use_cache) + + for model in MODEL_FALLBACK_CHAIN: + try: + yield from self._stream_api(model=model, messages=messages) + return + except Exception as e: + log.warning(f"Stream: model {model} failed: {e}") + + log.error("Stream: all models in fallback chain failed") + + # ═══ CONTEXT CACHING ═════════════════════════════════════════════════ + + def set_context(self, content: str, ttl_seconds: float = 3600.0): + """ + Cache project context to prepend on future calls. + + Args: + content: Context text (project docs, code, instructions) + ttl_seconds: Cache TTL (default: 1 hour) + """ + self._context_cache = ContextCache( + content=content, + ttl_seconds=ttl_seconds, + ) + log.info( + f"Context cached ({len(content)} chars, " + f"TTL={ttl_seconds}s, id={self._context_cache.cache_id})" + ) + + def clear_context(self): + """Clear the cached project context.""" + self._context_cache = None + log.info("Context cache cleared") + + def context_status(self) -> dict: + """Return cache status info.""" + if not self._context_cache: + return {"cached": False} + return { + "cached": True, + "cache_id": self._context_cache.cache_id, + "valid": self._context_cache.is_valid(), + "hit_count": self._context_cache.hit_count, + "age_seconds": time.time() - self._context_cache.created_at, + "content_length": len(self._context_cache.content), + } + + # ═══ INTERNAL: API CALLS ═════════════════════════════════════════════ + + def _call_api( + self, + model: str, + messages: list[dict], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + ) -> GeminiResponse: + """Make a single (non-streaming) call to the Gemini OpenAI-compat API.""" + url = f"{GEMINI_OPENAI_COMPAT_BASE}/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": 
"application/json", + } + payload: dict[str, Any] = { + "model": model, + "messages": messages, + "stream": False, + } + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if temperature is not None: + payload["temperature"] = temperature + + t0 = time.time() + try: + r = requests.post(url, json=payload, headers=headers, timeout=120) + latency_ms = (time.time() - t0) * 1000 + + if r.status_code != 200: + return GeminiResponse( + model=model, + latency_ms=latency_ms, + error=f"HTTP {r.status_code}: {r.text[:200]}", + ) + + data = r.json() + choice = data.get("choices", [{}])[0] + text = choice.get("message", {}).get("content", "") + usage = data.get("usage", {}) + input_tokens = usage.get("prompt_tokens", 0) + output_tokens = usage.get("completion_tokens", 0) + cost = self._estimate_cost(model, input_tokens, output_tokens) + + return GeminiResponse( + text=text, + model=model, + input_tokens=input_tokens, + output_tokens=output_tokens, + latency_ms=latency_ms, + cost_usd=cost, + ) + + except requests.Timeout: + return GeminiResponse( + model=model, + latency_ms=(time.time() - t0) * 1000, + error="Request timed out (120s)", + ) + except Exception as e: + return GeminiResponse( + model=model, + latency_ms=(time.time() - t0) * 1000, + error=str(e), + ) + + def _stream_api( + self, + model: str, + messages: list[dict], + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + ) -> Iterator[str]: + """Stream tokens from the Gemini OpenAI-compat API.""" + url = f"{GEMINI_OPENAI_COMPAT_BASE}/chat/completions" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + payload: dict[str, Any] = { + "model": model, + "messages": messages, + "stream": True, + } + if max_tokens is not None: + payload["max_tokens"] = max_tokens + if temperature is not None: + payload["temperature"] = temperature + + t0 = time.time() + input_tokens = 0 + output_tokens = 0 + + with requests.post( + url, json=payload, 
            headers=headers, stream=True, timeout=120
        ) as r:
            r.raise_for_status()
            # Parse the SSE stream: each event line is "data: <json>",
            # terminated by a "data: [DONE]" sentinel.
            for raw_line in r.iter_lines():
                if not raw_line:
                    continue
                line = raw_line.decode("utf-8") if isinstance(raw_line, bytes) else raw_line
                if not line.startswith("data: "):
                    continue
                payload_str = line[6:]
                if payload_str.strip() == "[DONE]":
                    break
                try:
                    chunk = json.loads(payload_str)
                    delta = chunk.get("choices", [{}])[0].get("delta", {})
                    content = delta.get("content", "")
                    if content:
                        output_tokens += 1  # rough estimate
                        yield content
                    # Capture usage if present in final chunk
                    # (authoritative counts override the per-chunk estimate).
                    usage = chunk.get("usage", {})
                    if usage:
                        input_tokens = usage.get("prompt_tokens", input_tokens)
                        output_tokens = usage.get("completion_tokens", output_tokens)
                except json.JSONDecodeError:
                    # Malformed / partial frame — skip it and keep streaming.
                    pass

        latency_ms = (time.time() - t0) * 1000
        cost = self._estimate_cost(model, input_tokens, output_tokens)
        resp = GeminiResponse(
            model=model,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            latency_ms=latency_ms,
            cost_usd=cost,
        )
        # Telemetry for the whole stream is emitted once, after completion.
        self._record(resp)

    # ═══ INTERNAL: HELPERS ═══════════════════════════════════════════════

    def _build_messages(
        self,
        prompt: Union[str, list[dict]],
        system: Optional[str] = None,
        use_cache: bool = True,
    ) -> list[dict]:
        """Build the messages list, injecting cached context if applicable.

        Order: explicit system instruction, then cached project context
        (as a second system message), then the user prompt/messages.
        """
        messages: list[dict] = []

        # System instruction
        if system:
            messages.append({"role": "system", "content": system})

        # Cached context prepended as assistant memory
        # (only while within its TTL; each injection counts as a hit).
        if use_cache and self._context_cache and self._context_cache.is_valid():
            self._context_cache.touch()
            messages.append({
                "role": "system",
                "content": f"[Project Context]\n{self._context_cache.content}",
            })

        # User message: a plain string is wrapped; a message list is
        # passed through unchanged (OpenAI-style dicts).
        if isinstance(prompt, str):
            messages.append({"role": "user", "content": prompt})
        else:
            messages.extend(prompt)

        return messages

    @staticmethod
    def _estimate_cost(model: str, input_tokens: int, output_tokens: int) -> 
float: + """Estimate USD cost from token counts.""" + in_rate = COST_PER_1M_INPUT.get(model, 3.50) + out_rate = COST_PER_1M_OUTPUT.get(model, 10.50) + return (input_tokens * in_rate + output_tokens * out_rate) / 1_000_000 + + def _record(self, response: GeminiResponse): + """Update session stats and emit telemetry for a completed response.""" + self.request_count += 1 + self.total_input_tokens += response.input_tokens + self.total_output_tokens += response.output_tokens + self.total_cost_usd += response.cost_usd + + log.info( + f"[{response.model}] {response.latency_ms:.0f}ms | " + f"in={response.input_tokens} out={response.output_tokens} | " + f"${response.cost_usd:.6f}" + ) + + # Fire-and-forget telemetry (don't block the caller) + try: + asyncio.get_event_loop().create_task( + self._send_telemetry({ + "type": "gemini_response", + "harness_id": HARNESS_ID, + "session_id": self.session_id, + "model": response.model, + "latency_ms": response.latency_ms, + "input_tokens": response.input_tokens, + "output_tokens": response.output_tokens, + "cost_usd": response.cost_usd, + "cached": response.cached, + "error": response.error, + }) + ) + except RuntimeError: + # No event loop running (sync context) — skip async telemetry + pass + + def _session_stats(self) -> dict: + return { + "session_id": self.session_id, + "request_count": self.request_count, + "total_input_tokens": self.total_input_tokens, + "total_output_tokens": self.total_output_tokens, + "total_cost_usd": round(self.total_cost_usd, 6), + } + + # ═══ HERMES WEBSOCKET ════════════════════════════════════════════════ + + async def _connect_hermes(self): + """Connect to Hermes WebSocket for telemetry.""" + try: + import websockets # type: ignore + self._ws = await websockets.connect(self.hermes_ws_url) + self._ws_connected = True + log.info(f"Connected to Hermes: {self.hermes_ws_url}") + except Exception as e: + log.warning(f"Hermes connection failed (telemetry disabled): {e}") + self._ws_connected = False + + 
async def _send_telemetry(self, data: dict): + """Send a telemetry event to Hermes.""" + if not self._ws_connected or not self._ws: + return + try: + await self._ws.send(json.dumps(data)) + except Exception as e: + log.warning(f"Telemetry send failed: {e}") + self._ws_connected = False + + # ═══ SOVEREIGN ORCHESTRATION REGISTRATION ════════════════════════════ + + def register_in_orchestration( + self, + orchestration_url: str = "http://localhost:8000/api/v1/workers/register", + ) -> bool: + """ + Register this harness as an available worker in sovereign orchestration. + + Sends a POST to the orchestration endpoint with harness metadata. + Returns True on success. + """ + payload = { + "worker_id": HARNESS_ID, + "name": HARNESS_NAME, + "session_id": self.session_id, + "model": self.model, + "fallback_chain": MODEL_FALLBACK_CHAIN, + "capabilities": ["text", "code", "multimodal", "streaming"], + "transport": "http+ws", + "registered_at": datetime.now(timezone.utc).isoformat(), + } + try: + r = requests.post(orchestration_url, json=payload, timeout=10) + if r.status_code in (200, 201): + log.info(f"Registered in orchestration: {orchestration_url}") + return True + log.warning( + f"Orchestration registration returned {r.status_code}: {r.text[:100]}" + ) + return False + except Exception as e: + log.warning(f"Orchestration registration failed: {e}") + return False + + +# ═══════════════════════════════════════════════════════════════════════════ +# HTTP SERVER — expose harness to the network +# ═══════════════════════════════════════════════════════════════════════════ + +def create_app(harness: GeminiHarness): + """ + Create a minimal HTTP app that exposes the harness to the network. 
+ + Endpoints: + POST /generate — text/code generation + POST /generate/stream — streaming text generation + POST /generate/code — code generation + GET /health — health check + GET /status — session stats + cache status + POST /context — set project context cache + DELETE /context — clear context cache + """ + try: + from http.server import BaseHTTPRequestHandler, HTTPServer + except ImportError: + raise RuntimeError("http.server not available") + + class GeminiHandler(BaseHTTPRequestHandler): + def log_message(self, fmt, *args): + log.info(f"HTTP {fmt % args}") + + def _read_body(self) -> dict: + length = int(self.headers.get("Content-Length", 0)) + raw = self.rfile.read(length) if length else b"{}" + return json.loads(raw) + + def _send_json(self, data: dict, status: int = 200): + body = json.dumps(data).encode() + self.send_response(status) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + def do_GET(self): + if self.path == "/health": + self._send_json({"status": "ok", "harness": HARNESS_ID}) + elif self.path == "/status": + self._send_json({ + **harness._session_stats(), + "model": harness.model, + "context": harness.context_status(), + }) + else: + self._send_json({"error": "Not found"}, 404) + + def do_POST(self): + body = self._read_body() + + if self.path == "/generate": + prompt = body.get("prompt", "") + system = body.get("system") + use_cache = body.get("use_cache", True) + response = harness.generate( + prompt, system=system, use_cache=use_cache + ) + self._send_json(response.to_dict()) + + elif self.path == "/generate/code": + task = body.get("task", "") + language = body.get("language", "python") + context = body.get("context") + response = harness.generate_code(task, language=language, context=context) + self._send_json(response.to_dict()) + + elif self.path == "/context": + content = body.get("content", "") + ttl = 
float(body.get("ttl_seconds", 3600.0)) + harness.set_context(content, ttl_seconds=ttl) + self._send_json({"status": "cached", **harness.context_status()}) + + else: + self._send_json({"error": "Not found"}, 404) + + def do_DELETE(self): + if self.path == "/context": + harness.clear_context() + self._send_json({"status": "cleared"}) + else: + self._send_json({"error": "Not found"}, 404) + + return HTTPServer, GeminiHandler + + +# ═══════════════════════════════════════════════════════════════════════════ +# CLI ENTRYPOINT +# ═══════════════════════════════════════════════════════════════════════════ + +async def _async_start(harness: GeminiHarness): + await harness.start() + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description=f"{HARNESS_NAME} — Timmy's Gemini harness worker", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + python -m nexus.gemini_harness "What is the meaning of sovereignty?" + python -m nexus.gemini_harness --model gemini-2.0-flash "Quick test" + python -m nexus.gemini_harness --serve --port 9300 + python -m nexus.gemini_harness --code "Write a fizzbuzz in Python" + +Environment Variables: + GOOGLE_API_KEY — required for all API calls + HERMES_WS_URL — Hermes telemetry endpoint + GEMINI_MODEL — override default model + """, + ) + parser.add_argument( + "prompt", + nargs="?", + default=None, + help="Prompt to send (omit to use --serve mode)", + ) + parser.add_argument( + "--model", + default=None, + help=f"Model to use (default: {GEMINI_MODEL_PRIMARY})", + ) + parser.add_argument( + "--serve", + action="store_true", + help="Start HTTP server to expose harness on the network", + ) + parser.add_argument( + "--port", + type=int, + default=9300, + help="HTTP server port (default: 9300)", + ) + parser.add_argument( + "--hermes-ws", + default=DEFAULT_HERMES_WS_URL, + help=f"Hermes WebSocket URL (default: {DEFAULT_HERMES_WS_URL})", + ) + parser.add_argument( + "--code", + metavar="TASK", + 
help="Generate code for TASK instead of plain text", + ) + parser.add_argument( + "--stream", + action="store_true", + help="Stream response chunks to stdout", + ) + args = parser.parse_args() + + harness = GeminiHarness( + model=args.model, + hermes_ws_url=args.hermes_ws, + ) + + if args.serve: + # Start harness registration then serve HTTP + asyncio.run(_async_start(harness)) + HTTPServer, GeminiHandler = create_app(harness) + server = HTTPServer(("0.0.0.0", args.port), GeminiHandler) + log.info(f"Serving on http://0.0.0.0:{args.port}") + log.info("Endpoints: /generate /generate/code /health /status /context") + try: + server.serve_forever() + except KeyboardInterrupt: + log.info("Shutting down server") + asyncio.run(harness.stop()) + return + + if args.code: + response = harness.generate_code(args.code) + elif args.prompt: + if args.stream: + for chunk in harness.stream_generate(args.prompt): + print(chunk, end="", flush=True) + print() + return + else: + response = harness.generate(args.prompt) + else: + parser.print_help() + return + + if response.error: + print(f"ERROR: {response.error}") + else: + print(response.text) + print( + f"\n[{response.model}] {response.latency_ms:.0f}ms | " + f"tokens: {response.input_tokens}→{response.output_tokens} | " + f"${response.cost_usd:.6f}", + flush=True, + ) + + +if __name__ == "__main__": + main() diff --git a/tests/test_gemini_harness.py b/tests/test_gemini_harness.py new file mode 100644 index 0000000..67fea9d --- /dev/null +++ b/tests/test_gemini_harness.py @@ -0,0 +1,566 @@ +#!/usr/bin/env python3 +""" +Gemini Harness Test Suite + +Tests for the Gemini 3.1 Pro harness implementing the Hermes/OpenClaw worker pattern. 
+ +Usage: + pytest tests/test_gemini_harness.py -v + pytest tests/test_gemini_harness.py -v -k "not live" + RUN_LIVE_TESTS=1 pytest tests/test_gemini_harness.py -v # real API calls +""" + +import json +import os +import sys +import time +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from nexus.gemini_harness import ( + COST_PER_1M_INPUT, + COST_PER_1M_OUTPUT, + GEMINI_MODEL_PRIMARY, + GEMINI_MODEL_SECONDARY, + GEMINI_MODEL_TERTIARY, + HARNESS_ID, + MODEL_FALLBACK_CHAIN, + ContextCache, + GeminiHarness, + GeminiResponse, +) + + +# ═══════════════════════════════════════════════════════════════════════════ +# FIXTURES +# ═══════════════════════════════════════════════════════════════════════════ + +@pytest.fixture +def harness(): + """Harness with a fake API key so no real calls are made in unit tests.""" + return GeminiHarness(api_key="fake-key-for-testing") + + +@pytest.fixture +def harness_with_context(harness): + """Harness with pre-loaded project context.""" + harness.set_context("Timmy is sovereign. 
Gemini is a worker on the network.") + return harness + + +@pytest.fixture +def mock_ok_response(): + """Mock requests.post that returns a successful Gemini API response.""" + mock = MagicMock() + mock.status_code = 200 + mock.json.return_value = { + "choices": [{"message": {"content": "Hello from Gemini"}}], + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + } + return mock + + +@pytest.fixture +def mock_error_response(): + """Mock requests.post that returns a 429 rate-limit error.""" + mock = MagicMock() + mock.status_code = 429 + mock.text = "Rate limit exceeded" + return mock + + +# ═══════════════════════════════════════════════════════════════════════════ +# GeminiResponse DATA CLASS +# ═══════════════════════════════════════════════════════════════════════════ + +class TestGeminiResponse: + def test_default_creation(self): + resp = GeminiResponse() + assert resp.text == "" + assert resp.model == "" + assert resp.input_tokens == 0 + assert resp.output_tokens == 0 + assert resp.latency_ms == 0.0 + assert resp.cost_usd == 0.0 + assert resp.cached is False + assert resp.error is None + assert resp.timestamp + + def test_to_dict_includes_all_fields(self): + resp = GeminiResponse( + text="hi", model="gemini-2.5-pro-preview-03-25", input_tokens=10, + output_tokens=5, latency_ms=120.5, cost_usd=0.000035, + ) + d = resp.to_dict() + assert d["text"] == "hi" + assert d["model"] == "gemini-2.5-pro-preview-03-25" + assert d["input_tokens"] == 10 + assert d["output_tokens"] == 5 + assert d["latency_ms"] == 120.5 + assert d["cost_usd"] == 0.000035 + assert d["cached"] is False + assert d["error"] is None + assert "timestamp" in d + + def test_error_response(self): + resp = GeminiResponse(error="HTTP 429: Rate limit") + assert resp.error == "HTTP 429: Rate limit" + assert resp.text == "" + + +# ═══════════════════════════════════════════════════════════════════════════ +# ContextCache +# ═══════════════════════════════════════════════════════════════════════════ + 
class TestContextCache:
    """ContextCache TTL validity, hit counting, and cache-ID uniqueness."""

    def test_valid_fresh_cache(self):
        cache = ContextCache(content="project context", ttl_seconds=3600.0)
        assert cache.is_valid()

    def test_expired_cache(self):
        # TTL (1 ms) is an order of magnitude shorter than the sleep (10 ms),
        # so the cache is guaranteed stale even on a slow CI box.
        cache = ContextCache(content="old context", ttl_seconds=0.001)
        time.sleep(0.01)
        assert not cache.is_valid()

    def test_hit_count_increments(self):
        cache = ContextCache(content="ctx")
        assert cache.hit_count == 0
        cache.touch()
        cache.touch()
        assert cache.hit_count == 2

    def test_unique_cache_ids(self):
        a = ContextCache()
        b = ContextCache()
        assert a.cache_id != b.cache_id


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — initialization
# ═══════════════════════════════════════════════════════════════════════════

class TestGeminiHarnessInit:
    """Constructor defaults, model override, session-ID generation, missing key."""

    def test_default_model(self, harness):
        assert harness.model == GEMINI_MODEL_PRIMARY

    def test_custom_model(self):
        h = GeminiHarness(api_key="key", model=GEMINI_MODEL_TERTIARY)
        assert h.model == GEMINI_MODEL_TERTIARY

    def test_session_id_generated(self, harness):
        # Session IDs are short (8-char) identifiers, presumably a UUID prefix.
        assert harness.session_id
        assert len(harness.session_id) == 8

    def test_no_api_key_warning(self, caplog):
        # Local import keeps the test self-contained regardless of what the
        # module header imports.
        import logging
        with caplog.at_level(logging.WARNING, logger="gemini"):
            GeminiHarness(api_key="")
        assert "GOOGLE_API_KEY" in caplog.text

    def test_no_api_key_returns_error_response(self):
        # With no key, generate() must fail soft (error field) rather than raise.
        h = GeminiHarness(api_key="")
        resp = h.generate("hello")
        assert resp.error is not None
        assert "GOOGLE_API_KEY" in resp.error


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — context caching
# ═══════════════════════════════════════════════════════════════════════════

class TestContextCaching:
    """set/clear context, injection into message lists, and TTL expiry."""

    def test_set_context(self, harness):
        harness.set_context("Project context here", ttl_seconds=600.0)
        status = harness.context_status()
        assert status["cached"] is True
        assert status["valid"] is True
        assert status["content_length"] == len("Project context here")

    def test_clear_context(self, harness_with_context):
        harness_with_context.clear_context()
        assert harness_with_context.context_status()["cached"] is False

    def test_context_injected_in_messages(self, harness_with_context):
        messages = harness_with_context._build_messages("Hello", use_cache=True)
        contents = " ".join(m["content"] for m in messages if isinstance(m["content"], str))
        assert "Timmy is sovereign" in contents

    def test_context_skipped_when_use_cache_false(self, harness_with_context):
        messages = harness_with_context._build_messages("Hello", use_cache=False)
        contents = " ".join(m["content"] for m in messages if isinstance(m["content"], str))
        assert "Timmy is sovereign" not in contents

    def test_expired_context_not_injected(self, harness):
        harness.set_context("expired ctx", ttl_seconds=0.001)
        time.sleep(0.01)
        messages = harness._build_messages("Hello", use_cache=True)
        contents = " ".join(m["content"] for m in messages if isinstance(m["content"], str))
        assert "expired ctx" not in contents

    def test_cache_hit_count_increments(self, harness_with_context):
        # Each _build_messages call that injects cached context counts as a hit.
        harness_with_context._build_messages("q1", use_cache=True)
        harness_with_context._build_messages("q2", use_cache=True)
        assert harness_with_context._context_cache.hit_count == 2

    def test_context_status_no_cache(self, harness):
        status = harness.context_status()
        assert status == {"cached": False}


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — cost estimation
# ═══════════════════════════════════════════════════════════════════════════

class TestCostEstimation:
    """Cost math against the per-1M-token price tables and fallback ordering."""

    def test_cost_zero_tokens(self, harness):
        cost = harness._estimate_cost(GEMINI_MODEL_PRIMARY, 0, 0)
        assert cost == 0.0

    def test_cost_primary_model(self, harness):
        # 1M in + 1M out should cost exactly one unit of each price-table entry.
        cost = harness._estimate_cost(GEMINI_MODEL_PRIMARY, 1_000_000, 1_000_000)
        expected = COST_PER_1M_INPUT[GEMINI_MODEL_PRIMARY] + COST_PER_1M_OUTPUT[GEMINI_MODEL_PRIMARY]
        assert abs(cost - expected) < 0.0001

    def test_cost_tertiary_cheaper_than_primary(self, harness):
        cost_primary = harness._estimate_cost(GEMINI_MODEL_PRIMARY, 100_000, 100_000)
        cost_tertiary = harness._estimate_cost(GEMINI_MODEL_TERTIARY, 100_000, 100_000)
        assert cost_tertiary < cost_primary

    def test_fallback_chain_order(self):
        assert MODEL_FALLBACK_CHAIN[0] == GEMINI_MODEL_PRIMARY
        assert MODEL_FALLBACK_CHAIN[1] == GEMINI_MODEL_SECONDARY
        assert MODEL_FALLBACK_CHAIN[2] == GEMINI_MODEL_TERTIARY


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — generate (mocked HTTP)
# ═══════════════════════════════════════════════════════════════════════════

class TestGenerate:
    """generate() happy path, model fallback, stats, and payload shape."""

    def test_generate_success(self, harness, mock_ok_response):
        with patch("requests.post", return_value=mock_ok_response):
            resp = harness.generate("Hello Timmy")

        assert resp.error is None
        assert resp.text == "Hello from Gemini"
        assert resp.input_tokens == 10
        assert resp.output_tokens == 5
        assert resp.model == GEMINI_MODEL_PRIMARY

    def test_generate_uses_fallback_on_error(self, harness, mock_ok_response, mock_error_response):
        """First model fails, second succeeds."""
        call_count = [0]

        def side_effect(*args, **kwargs):
            call_count[0] += 1
            if call_count[0] == 1:
                return mock_error_response
            return mock_ok_response

        with patch("requests.post", side_effect=side_effect):
            resp = harness.generate("Hello")

        assert resp.error is None
        assert call_count[0] == 2
        assert resp.model == GEMINI_MODEL_SECONDARY

    def test_generate_all_fail_returns_error(self, harness, mock_error_response):
        # Every model in the fallback chain fails -> error response, no raise.
        with patch("requests.post", return_value=mock_error_response):
            resp = harness.generate("Hello")

        assert resp.error is not None
        assert "failed" in resp.error.lower()

    def test_generate_updates_session_stats(self, harness, mock_ok_response):
        with patch("requests.post", return_value=mock_ok_response):
            harness.generate("q1")
            harness.generate("q2")

        assert harness.request_count == 2
        assert harness.total_input_tokens == 20
        assert harness.total_output_tokens == 10

    def test_generate_with_system_prompt(self, harness, mock_ok_response):
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate("Hello", system="You are helpful")

        call_kwargs = mock_post.call_args
        payload = call_kwargs[1]["json"]
        roles = [m["role"] for m in payload["messages"]]
        assert "system" in roles

    def test_generate_string_prompt_wrapped(self, harness, mock_ok_response):
        # A bare string prompt must become a single user message.
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate("Test prompt")

        payload = mock_post.call_args[1]["json"]
        user_msgs = [m for m in payload["messages"] if m["role"] == "user"]
        assert len(user_msgs) == 1
        assert user_msgs[0]["content"] == "Test prompt"

    def test_generate_list_prompt_passed_through(self, harness, mock_ok_response):
        messages = [
            {"role": "user", "content": "first"},
            {"role": "assistant", "content": "reply"},
            {"role": "user", "content": "follow up"},
        ]
        with patch("requests.post", return_value=mock_ok_response):
            resp = harness.generate(messages)

        assert resp.error is None


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — generate_code
# ═══════════════════════════════════════════════════════════════════════════

class TestGenerateCode:
    """generate_code() system-prompt injection and context passing."""

    def test_generate_code_success(self, harness, mock_ok_response):
        with patch("requests.post", return_value=mock_ok_response):
            resp = harness.generate_code("write a hello world", language="python")

        assert resp.error is None
        assert resp.text == "Hello from Gemini"

    def test_generate_code_injects_system(self, harness, mock_ok_response):
        # The requested language should appear in an injected system message.
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_code("fizzbuzz", language="go")

        payload = mock_post.call_args[1]["json"]
        system_msgs = [m for m in payload["messages"] if m["role"] == "system"]
        assert any("go" in m["content"].lower() for m in system_msgs)

    def test_generate_code_with_context(self, harness, mock_ok_response):
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_code("extend this", context="def foo(): pass")

        payload = mock_post.call_args[1]["json"]
        user_msgs = [m for m in payload["messages"] if m["role"] == "user"]
        assert "foo" in user_msgs[0]["content"]


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — generate_multimodal
# ═══════════════════════════════════════════════════════════════════════════

class TestGenerateMultimodal:
    """generate_multimodal() content-part construction for images."""

    def test_multimodal_text_only(self, harness, mock_ok_response):
        with patch("requests.post", return_value=mock_ok_response):
            resp = harness.generate_multimodal("Describe this")

        assert resp.error is None

    def test_multimodal_with_base64_image(self, harness, mock_ok_response):
        # Base64 images are embedded as data: URLs with the given MIME type.
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_multimodal(
                "What is in this image?",
                images=[{"type": "base64", "data": "abc123", "mime": "image/jpeg"}],
            )

        payload = mock_post.call_args[1]["json"]
        content = payload["messages"][0]["content"]
        image_parts = [p for p in content if p.get("type") == "image_url"]
        assert len(image_parts) == 1
        assert "data:image/jpeg;base64,abc123" in image_parts[0]["image_url"]["url"]

    def test_multimodal_with_url_image(self, harness, mock_ok_response):
        # Plain URL images are passed through untouched.
        with patch("requests.post", return_value=mock_ok_response) as mock_post:
            harness.generate_multimodal(
                "What is this?",
                images=[{"type": "url", "url": "http://example.com/img.png"}],
            )

        payload = mock_post.call_args[1]["json"]
        content = payload["messages"][0]["content"]
        image_parts = [p for p in content if p.get("type") == "image_url"]
        assert image_parts[0]["image_url"]["url"] == "http://example.com/img.png"


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — session stats
# ═══════════════════════════════════════════════════════════════════════════

class TestSessionStats:
    """_session_stats() counters before and after mocked calls."""

    def test_session_stats_initial(self, harness):
        stats = harness._session_stats()
        assert stats["request_count"] == 0
        assert stats["total_input_tokens"] == 0
        assert stats["total_output_tokens"] == 0
        assert stats["total_cost_usd"] == 0.0
        assert stats["session_id"] == harness.session_id

    def test_session_stats_after_calls(self, harness, mock_ok_response):
        with patch("requests.post", return_value=mock_ok_response):
            harness.generate("a")
            harness.generate("b")

        stats = harness._session_stats()
        assert stats["request_count"] == 2
        assert stats["total_input_tokens"] == 20
        assert stats["total_output_tokens"] == 10


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — orchestration registration
# ═══════════════════════════════════════════════════════════════════════════

class TestOrchestrationRegistration:
    """register_in_orchestration() success/failure paths and payload shape."""

    def test_register_success(self, harness):
        mock_resp = MagicMock()
        mock_resp.status_code = 201
        with patch("requests.post", return_value=mock_resp):
            result = harness.register_in_orchestration("http://localhost:8000/api/v1/workers/register")

        assert result is True

    def test_register_failure_returns_false(self, harness):
        mock_resp = MagicMock()
        mock_resp.status_code = 500
        mock_resp.text = "Internal error"

        with patch("requests.post", return_value=mock_resp):
            result = harness.register_in_orchestration("http://localhost:8000/api/v1/workers/register")

        assert result is False

    def test_register_connection_error_returns_false(self, harness):
        # A transport-level failure must be swallowed and reported as False.
        with patch("requests.post", side_effect=Exception("Connection refused")):
            result = harness.register_in_orchestration("http://localhost:9999/register")

        assert result is False

    def test_register_payload_contains_capabilities(self, harness):
        mock_resp = MagicMock()
        mock_resp.status_code = 200

        with patch("requests.post", return_value=mock_resp) as mock_post:
            harness.register_in_orchestration("http://localhost/register")

        payload = mock_post.call_args[1]["json"]
        assert payload["worker_id"] == HARNESS_ID
        assert "text" in payload["capabilities"]
        assert "multimodal" in payload["capabilities"]
        assert "streaming" in payload["capabilities"]
        assert "code" in payload["capabilities"]
        assert len(payload["fallback_chain"]) == 3


# ═══════════════════════════════════════════════════════════════════════════
# GeminiHarness — async lifecycle (Hermes WS)
# ═══════════════════════════════════════════════════════════════════════════

class TestAsyncLifecycle:
    """start()/stop() must degrade gracefully when Hermes is unreachable."""

    @pytest.mark.asyncio
    async def test_start_without_hermes(self, harness):
        """Start should succeed even if Hermes is not reachable."""
        harness.hermes_ws_url = "ws://localhost:19999/ws"
        # Should not raise
        await harness.start()
        assert harness._ws_connected is False

    @pytest.mark.asyncio
    async def test_stop_without_connection(self, harness):
        """Stop should succeed gracefully with no WS connection."""
        await harness.stop()


# ═══════════════════════════════════════════════════════════════════════════
# HTTP server smoke test
# ═══════════════════════════════════════════════════════════════════════════

class TestHTTPServer:
    """create_app() wiring and GET routing, exercised without a real socket."""

    def test_create_app_returns_classes(self, harness):
        from nexus.gemini_harness import create_app
        HTTPServer, GeminiHandler = create_app(harness)
        assert HTTPServer is not None
        assert GeminiHandler is not None

    def test_health_handler(self, harness):
        """Verify health endpoint handler logic via direct method call."""
        from nexus.gemini_harness import create_app
        _, GeminiHandler = create_app(harness)

        # __new__ skips BaseHTTPRequestHandler.__init__, so no socket is needed;
        # _send_json is stubbed to capture (status, data) pairs instead of writing.
        handler = GeminiHandler.__new__(GeminiHandler)
        responses = []
        handler._send_json = lambda data, status=200: responses.append((status, data))
        handler.path = "/health"
        handler.do_GET()
        assert len(responses) == 1
        assert responses[0][0] == 200
        assert responses[0][1]["status"] == "ok"
        assert responses[0][1]["harness"] == HARNESS_ID

    def test_status_handler(self, harness):
        from nexus.gemini_harness import create_app
        _, GeminiHandler = create_app(harness)

        handler = GeminiHandler.__new__(GeminiHandler)
        responses = []
        handler._send_json = lambda data, status=200: responses.append((status, data))
        handler.path = "/status"
        handler.do_GET()

        assert responses[0][1]["request_count"] == 0
        assert responses[0][1]["model"] == harness.model

    def test_unknown_get_returns_404(self, harness):
        from nexus.gemini_harness import create_app
        _, GeminiHandler = create_app(harness)

        handler = GeminiHandler.__new__(GeminiHandler)
        responses = []
        handler._send_json = lambda data, status=200: responses.append((status, data))
        handler.path = "/nonexistent"
        handler.do_GET()

        assert responses[0][0] == 404


# ═══════════════════════════════════════════════════════════════════════════
# Live API tests (skipped unless RUN_LIVE_TESTS=1 and GOOGLE_API_KEY set)
# ═══════════════════════════════════════════════════════════════════════════

def _live_tests_enabled():
    """True only when the caller explicitly opted in AND a key is present."""
    return (
        os.environ.get("RUN_LIVE_TESTS") == "1"
        and bool(os.environ.get("GOOGLE_API_KEY"))
    )


@pytest.mark.skipif(
    not _live_tests_enabled(),
    reason="Live tests require RUN_LIVE_TESTS=1 and GOOGLE_API_KEY",
)
class TestLiveAPI:
    """Integration tests that hit the real Gemini API."""

    @pytest.fixture
    def live_harness(self):
        return GeminiHarness()

    def test_live_generate(self, live_harness):
        resp = live_harness.generate("Say 'pong' and nothing else.")
        assert resp.error is None
        assert resp.text.strip().lower().startswith("pong")
        assert resp.input_tokens > 0
        assert resp.latency_ms > 0

    def test_live_generate_code(self, live_harness):
        resp = live_harness.generate_code("write a function that returns 42", language="python")
        assert resp.error is None
        assert "42" in resp.text

    def test_live_stream(self, live_harness):
        chunks = list(live_harness.stream_generate("Count to 3: one, two, three."))
        assert len(chunks) > 0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])