"""Provider API call mixin for the Cascade Router.
|
|
|
|
Contains methods for calling individual LLM provider APIs
|
|
(Ollama, OpenAI, Anthropic, Grok, vllm-mlx).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import base64
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from config import settings
|
|
|
|
from .models import ContentType, Provider
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
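# Note: provider SDKs (aiohttp, openai, anthropic, httpx) are imported lazily
# inside each _call_* method rather than at module top level, so a missing SDK
# only fails when the corresponding provider is actually tried.
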
class ProviderCallsMixin:
    """Mixin providing LLM provider API call methods.

    Expects the consuming class to have:
    - self.config: RouterConfig

    Every _call_* helper returns {"content": str, "model": str};
    _try_provider adds "latency_ms" to the result.
    """

    async def _try_provider(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Try a single provider request."""
        start_time = time.time()

        if provider.type == "ollama":
            result = await self._call_ollama(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
                content_type=content_type,
            )
        elif provider.type == "openai":
            result = await self._call_openai(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
            )
        elif provider.type == "anthropic":
            result = await self._call_anthropic(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
            )
        elif provider.type == "grok":
            result = await self._call_grok(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
            )
        elif provider.type == "vllm_mlx":
            result = await self._call_vllm_mlx(
                provider=provider,
                messages=messages,
                model=model or provider.get_default_model(),
                temperature=temperature,
                max_tokens=max_tokens,
            )
        else:
            raise ValueError(f"Unknown provider type: {provider.type}")

        latency_ms = (time.time() - start_time) * 1000
        result["latency_ms"] = latency_ms

        return result

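    # Usage sketch (hypothetical consumer class; RouterConfig comes from the
    # consuming code, not this module):
    #
    #   class CascadeRouter(ProviderCallsMixin):
    #       def __init__(self, config: RouterConfig) -> None:
    #           self.config = config
    #
    #   result = await router._try_provider(
    #       provider=provider,
    #       messages=[{"role": "user", "content": "Hello"}],
    #       model="",  # falls back to provider.get_default_model()
    #       temperature=0.7,
    #       max_tokens=256,
    #   )
    #   # result == {"content": "...", "model": "...", "latency_ms": 12.3}
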
    async def _call_ollama(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None = None,
        content_type: ContentType = ContentType.TEXT,
    ) -> dict:
        """Call Ollama API with multi-modal support."""
        import aiohttp

        url = f"{provider.url or settings.ollama_url}/api/chat"

        # Transform messages for Ollama format (including images)
        transformed_messages = self._transform_messages_for_ollama(messages)

        options: dict[str, Any] = {"temperature": temperature}
        if max_tokens:
            # Ollama's equivalent of max_tokens is the num_predict option
            options["num_predict"] = max_tokens

        payload = {
            "model": model,
            "messages": transformed_messages,
            "stream": False,
            "options": options,
        }

        timeout = aiohttp.ClientTimeout(total=self.config.timeout_seconds)

        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as response:
                if response.status != 200:
                    text = await response.text()
                    raise RuntimeError(f"Ollama error {response.status}: {text}")

                data = await response.json()
                return {
                    "content": data["message"]["content"],
                    "model": model,
                }

    def _transform_messages_for_ollama(self, messages: list[dict]) -> list[dict]:
        """Transform messages to Ollama format, handling images."""
        transformed = []

        for msg in messages:
            new_msg: dict[str, Any] = {
                "role": msg.get("role", "user"),
                "content": msg.get("content", ""),
            }

            # Handle images
            images = msg.get("images", [])
            if images:
                new_msg["images"] = []
                for img in images:
                    if isinstance(img, str):
                        if img.startswith("data:image/"):
                            # Data URI: keep only the base64 payload after the
                            # first comma ("data:image/png;base64,<payload>")
                            new_msg["images"].append(img.split(",", 1)[1])
                        elif img.startswith("http://") or img.startswith("https://"):
                            # URL - would need to download, skip for now
                            logger.warning("Image URLs not yet supported, skipping: %s", img)
                        elif Path(img).exists():
                            # Local file path - read and encode
                            try:
                                with open(img, "rb") as f:
                                    img_data = base64.b64encode(f.read()).decode()
                                new_msg["images"].append(img_data)
                            except Exception as exc:
                                logger.error("Failed to read image %s: %s", img, exc)

            transformed.append(new_msg)

        return transformed

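    # Example of the transformation above (hypothetical values): the message
    #   {"role": "user", "content": "What is this?",
    #    "images": ["data:image/png;base64,iVBORw0KGgo..."]}
    # becomes
    #   {"role": "user", "content": "What is this?",
    #    "images": ["iVBORw0KGgo..."]}
    # since Ollama's /api/chat expects raw base64 strings in "images".
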
    async def _call_openai(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call OpenAI API."""
        import openai

        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url,
            timeout=self.config.timeout_seconds,
        )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        response = await client.chat.completions.create(**kwargs)

        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    async def _call_anthropic(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call Anthropic API."""
        import anthropic

        client = anthropic.AsyncAnthropic(
            api_key=provider.api_key,
            timeout=self.config.timeout_seconds,
        )

        # Convert messages to Anthropic format: the system prompt is a
        # top-level parameter, not a message in the conversation
        system_msg = None
        conversation = []
        for msg in messages:
            if msg["role"] == "system":
                system_msg = msg["content"]
            else:
                conversation.append(
                    {
                        "role": msg["role"],
                        "content": msg["content"],
                    }
                )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": conversation,
            "temperature": temperature,
            # The Messages API requires max_tokens, so fall back to 1024
            "max_tokens": max_tokens or 1024,
        }
        if system_msg:
            kwargs["system"] = system_msg

        response = await client.messages.create(**kwargs)

        return {
            "content": response.content[0].text,
            "model": response.model,
        }

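    # Example of the conversion above:
    #   [{"role": "system", "content": "Be terse."},
    #    {"role": "user", "content": "Hi"}]
    # is sent as messages=[{"role": "user", "content": "Hi"}] with
    # system="Be terse." passed as a top-level parameter.
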
    async def _call_grok(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call xAI Grok API via OpenAI-compatible SDK."""
        import httpx
        import openai

        client = openai.AsyncOpenAI(
            api_key=provider.api_key,
            base_url=provider.base_url or settings.xai_base_url,
            # Fixed 300 s timeout, independent of config.timeout_seconds
            timeout=httpx.Timeout(300.0),
        )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        response = await client.chat.completions.create(**kwargs)

        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }

    async def _call_vllm_mlx(
        self,
        provider: Provider,
        messages: list[dict],
        model: str,
        temperature: float,
        max_tokens: int | None,
    ) -> dict:
        """Call vllm-mlx via its OpenAI-compatible API.

        vllm-mlx exposes the same /v1/chat/completions endpoint as OpenAI,
        so we reuse the OpenAI client pointed at the local server.
        No API key is required for local deployments.
        """
        import openai

        base_url = provider.base_url or provider.url or "http://localhost:8000"
        # Ensure the base_url ends with /v1 as expected by the OpenAI client,
        # e.g. "http://localhost:8000" -> "http://localhost:8000/v1"
        if not base_url.rstrip("/").endswith("/v1"):
            base_url = base_url.rstrip("/") + "/v1"

        client = openai.AsyncOpenAI(
            api_key=provider.api_key or "no-key-required",
            base_url=base_url,
            timeout=self.config.timeout_seconds,
        )

        kwargs: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        if max_tokens:
            kwargs["max_tokens"] = max_tokens

        response = await client.chat.completions.create(**kwargs)

        return {
            "content": response.choices[0].message.content,
            "model": response.model,
        }