"""Provider API call mixin for the Cascade Router.

Contains methods for calling individual LLM provider APIs
(Ollama, OpenAI, Anthropic, Grok, vllm-mlx).
"""

from __future__ import annotations

import base64
import logging
import time
from pathlib import Path
from typing import Any

from config import settings

from .models import ContentType, Provider

logger = logging.getLogger(__name__)


class ProviderCallsMixin:
    """Mixin providing LLM provider API call methods.

    Expects the consuming class to have:
        - self.config: RouterConfig (only ``self.config.timeout_seconds`` is
          read here)
    """

    async def _try_provider(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Send a single chat request to ``provider`` and time it.

        Args:
            provider: Provider configuration; ``provider.type`` selects the backend.
            messages: Chat messages (dicts with at least ``role``/``content``).
            model: Model name; falls back to ``provider.get_default_model()``
                when empty.
            temperature: Sampling temperature forwarded to the backend.
            max_tokens: Optional completion-length cap forwarded to the backend.
            content_type: Forwarded only to the Ollama backend.

        Returns:
            The backend's result dict (``content`` and ``model`` keys) with an
            added ``latency_ms`` float.

        Raises:
            ValueError: If ``provider.type`` is not a supported backend.
            Any exception raised by the selected backend call propagates.
        """
        # perf_counter is monotonic; wall-clock time.time() can jump and skew latency.
        start_time = time.perf_counter()

        if provider.type == "ollama":
            handler = self._call_ollama
        elif provider.type == "openai":
            handler = self._call_openai
        elif provider.type == "anthropic":
            handler = self._call_anthropic
        elif provider.type == "grok":
            handler = self._call_grok
        elif provider.type == "vllm_mlx":
            handler = self._call_vllm_mlx
        else:
            raise ValueError(f"Unknown provider type: {provider.type}")

        call_kwargs: dict[str, Any] = {
            "provider": provider,
            "messages": messages,
            "model": model or provider.get_default_model(),
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Only the Ollama backend takes a content type.
        if provider.type == "ollama":
            call_kwargs["content_type"] = content_type

        result = await handler(**call_kwargs)

        result["latency_ms"] = (time.perf_counter() - start_time) * 1000
        return result

    async def _call_ollama(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None = None,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Call the Ollama ``/api/chat`` endpoint (non-streaming), with image support.

        Args:
            provider: Provider config; ``provider.url`` overrides ``settings.ollama_url``.
            messages: Chat messages; optional ``"images"`` entries are converted
                by ``_transform_messages_for_ollama``.
            model: Model name sent to Ollama.
            temperature: Sampling temperature (``options.temperature``).
            max_tokens: Optional cap sent as ``options.num_predict`` when truthy.
            content_type: Accepted for interface parity with ``_try_provider``;
                not used by this method.

        Returns:
            ``{"content": <assistant message text>, "model": <requested model>}``.

        Raises:
            RuntimeError: If Ollama responds with a non-200 status.
        """
        import aiohttp

        # Strip a trailing slash so we never build ".../" + "/api/chat".
        base_url = (provider.url or settings.ollama_url).rstrip("/")
        url = f"{base_url}/api/chat"

        # Transform messages for Ollama format (including images)
        transformed_messages = self._transform_messages_for_ollama(messages)

        options: dict[str, Any] = {"temperature": temperature}
        if max_tokens:
            options["num_predict"] = max_tokens

        payload = {
            "model": model,
            "messages": transformed_messages,
            "stream": False,
            "options": options,
        }

        timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    text = await response.text()
                    raise RuntimeError(f"Ollama error {response.status}: {text}")

                data = await response.json()
                return {
                    "content": data["message"]["content"],
                    "model": model,
                }

    def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
        """Transform messages to Ollama format, handling images.

        Each output message has ``role`` (default ``"user"``) and ``content``
        (default ``""``). When the input message has a non-empty ``images``
        list, the output gets an ``images`` list of bare base64 strings; entries
        that cannot be converted are logged and skipped.
        """
        transformed = []
        for msg in messages:
            new_msg: dict[str, Any] = {
                "role": msg.get("role", "user"),
                "content": msg.get("content", ""),
            }

            images = msg.get("images", [])
            if images:
                encoded = [self._encode_image_for_ollama(img) for img in images]
                new_msg["images"] = [data for data in encoded if data is not None]

            transformed.append(new_msg)

        return transformed

    @staticmethod
    def _encode_image_for_ollama(img: Any) -> str | None:
        """Convert one image reference to a bare base64 string, or ``None`` to skip it.

        Supported inputs: ``data:image/...`` URIs (the payload after the first
        comma is used) and paths to existing local files (read and encoded).
        HTTP(S) URLs, malformed data URIs, other strings and non-string values
        are logged and skipped.
        """
        if not isinstance(img, str):
            logger.warning("Unsupported image type %s, skipping", type(img).__name__)
            return None

        if img.startswith("data:image/"):
            # Base64 encoded image: keep only the payload after the header.
            _header, sep, payload = img.partition(",")
            if not sep:
                logger.warning("Malformed image data URI (no comma), skipping")
                return None
            return payload

        if img.startswith(("http://", "https://")):
            # URL - would need to download, skip for now
            logger.warning("Image URLs not yet supported, skipping: %s", img)
            return None

        try:
            is_existing_path = Path(img).exists()
        except (OSError, ValueError):
            # e.g. a long raw string can exceed the OS filename-length limit
            is_existing_path = False

        if not is_existing_path:
            logger.warning(
                "Image is not a data URI, URL, or existing file (%d chars), skipping",
                len(img),
            )
            return None

        # Local file path - read and encode
        try:
            return base64.b64encode(Path(img).read_bytes()).decode()
        except OSError as exc:
            logger.error("Failed to read image %s: %s", img, exc)
            return None

    async def _openai_compatible_chat(
        self,
        client: Any,
        model: str,
        messages: list[dict],
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Run one chat completion on an OpenAI-compatible client, then close it.

        Shared by the OpenAI, Grok and vllm-mlx backends. ``max_tokens`` is only
        sent when truthy.

        Returns:
            ``{"content": <first choice message content>, "model": <response model>}``.
        """
        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        # A new client is built for every request; ``async with`` closes its
        # HTTP connection pool afterwards instead of leaking it.
        async with client:
            response = await client.chat.completions.create(**kwargs)

        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    async def _call_openai(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call the OpenAI chat completions API.

        Uses ``provider.api_key`` and ``provider.base_url`` with the configured
        ``self.config.timeout_seconds``.
        """
        import openai

        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url,
            timeout=self.config.timeout_seconds,
        )
        return await self._openai_compatible_chat(
            client,
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    async def _call_anthropic(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call the Anthropic Messages API.

        ``system``-role messages are moved to the top-level ``system`` parameter.
        If there are several, they are joined with blank lines; this assumes
        their content is text (TODO confirm callers never send block lists
        there). All other messages are passed through in order.
        ``max_tokens`` defaults to 1024 because Anthropic requires it.

        Returns:
            ``{"content": <concatenated text blocks>, "model": <response model>}``.
        """
        import anthropic

        client = anthropic.AsyncAnthropic(
            api_key=provider.api_key,
            timeout=self.config.timeout_seconds,
        )

        # Convert messages to Anthropic format
        system_parts: list[Any] = []
        conversation: list[dict] = []
        for msg in messages:
            if msg["role"] == "system":
                system_parts.append(msg["content"])
            else:
                conversation.append(
                    {
                        "role": msg["role"],
                        "content": msg["content"],
                    }
                )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": conversation,
            "temperature": temperature,
            "max_tokens": max_tokens or 1024,
        }

        # Keep every system message instead of letting the last one win;
        # a single one is passed through unchanged.
        if len(system_parts) == 1:
            system_msg = system_parts[0]
        else:
            system_msg = "\n\n".join(system_parts)
        if system_msg:
            kwargs["system"] = system_msg

        async with client:
            response = await client.messages.create(**kwargs)

        # Concatenate text blocks; tolerates empty content and non-text blocks
        # (e.g. tool use) that have no ``text`` attribute.
        text_parts: list[str] = []
        for block in response.content:
            text = getattr(block, "text", None)
            if isinstance(text, str):
                text_parts.append(text)

        return {
            "content": "".join(text_parts),
            "model": response.model,
        }

    async def _call_grok(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call the xAI Grok API via the OpenAI-compatible SDK.

        ``provider.base_url`` overrides ``settings.xai_base_url``.
        """
        import httpx
        import openai

        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url or settings.xai_base_url,
            # NOTE(review): fixed 300 s timeout, unlike the other providers,
            # which use self.config.timeout_seconds. Presumably deliberate for
            # slow models — confirm.
            timeout=httpx.Timeout(300.0),
        )
        return await self._openai_compatible_chat(
            client,
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )

    async def _call_vllm_mlx(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call vllm-mlx via its OpenAI-compatible API.

        vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI,
        so we reuse the OpenAI client pointed at the local server.
        No API key is required for local deployments, so a placeholder is sent
        when ``provider.api_key`` is empty.
        """
        import openai

        base_url = provider.base_url or provider.url or "http://localhost:8000"
        # Ensure the base_url ends with /v1 as expected by the OpenAI client
        if not base_url.rstrip("/").endswith("/v1"):
            base_url = base_url.rstrip("/") + "/v1"

        client = openai.AsyncOpenAI(
            api_key=provider.api_key or "no-key-required",
            base_url=base_url,
            timeout=self.config.timeout_seconds,
        )
        return await self._openai_compatible_chat(
            client,
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )