from __future__ import annotations

import os
from pathlib import Path
from typing import Any

import yaml

DEFAULT_CONFIG_PATH = Path(__file__).resolve().parents[1] / "config.yaml"


def _normalize_base_url(base_url: str) -> str:
    """Strip any trailing slash so path segments can be appended cleanly."""
    return (base_url or "").rstrip("/")


def load_big_brain_provider(config_path: str | Path = DEFAULT_CONFIG_PATH) -> dict[str, Any]:
    """Return the "Big Brain" entry from the config's custom_providers list."""
    config = yaml.safe_load(Path(config_path).read_text()) or {}
    for provider in config.get("custom_providers", []):
        if provider.get("name") == "Big Brain":
            return dict(provider)
    raise KeyError("Big Brain provider not found in config")


def infer_backend(base_url: str) -> str:
    """OpenAI-compatible base URLs end in /v1; anything else is treated as Ollama."""
    base = _normalize_base_url(base_url)
    return "openai" if base.endswith("/v1") else "ollama"


def resolve_big_brain_provider(config_path: str | Path = DEFAULT_CONFIG_PATH) -> dict[str, Any]:
    """Resolve provider settings, letting BIG_BRAIN_* environment variables override the config."""
    provider = load_big_brain_provider(config_path)
    base_url = _normalize_base_url(os.environ.get("BIG_BRAIN_BASE_URL", provider.get("base_url", "")))
    model = os.environ.get("BIG_BRAIN_MODEL", provider.get("model", "gemma4:latest"))
    backend = os.environ.get("BIG_BRAIN_BACKEND", infer_backend(base_url))
    api_key = os.environ.get("BIG_BRAIN_API_KEY", provider.get("api_key", ""))
    return {
        "name": provider.get("name", "Big Brain"),
        "base_url": base_url,
        "model": model,
        "backend": backend,
        "api_key": api_key,
    }


def resolve_models_url(provider: dict[str, Any]) -> str:
    """Return the model-listing endpoint for the provider's backend."""
    base = _normalize_base_url(provider["base_url"])
    if provider["backend"] == "openai":
        return f"{base}/models"
    return f"{base}/api/tags"


def resolve_generate_url(provider: dict[str, Any]) -> str:
    """Return the generation endpoint for the provider's backend."""
    base = _normalize_base_url(provider["base_url"])
    if provider["backend"] == "openai":
        return f"{base}/chat/completions"
    return f"{base}/api/generate"


def build_generate_payload(provider: dict[str, Any], prompt: str = "Say READY") -> dict[str, Any]:
    """Build a minimal non-streaming generation payload for the provider's backend."""
    if provider["backend"] == "openai":
        return {
            "model": provider["model"],
            "messages": [{"role": "user", "content": prompt}],
            "stream": False,
            "max_tokens": 32,
        }
    return {
        "model": provider["model"],
        "prompt": prompt,
        "stream": False,
        "options": {"num_predict": 32},
    }
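

# A minimal usage sketch, not part of the original module: it assumes the config
# file and BIG_BRAIN_* environment variables referenced above are in place and
# that the resolved endpoint accepts the payload from build_generate_payload.
# Kept to the standard library so it runs without extra dependencies.
if __name__ == "__main__":
    import json
    import urllib.request

    provider = resolve_big_brain_provider()
    url = resolve_generate_url(provider)
    payload = build_generate_payload(provider)

    headers = {"Content-Type": "application/json"}
    if provider["api_key"]:
        # Only OpenAI-style backends typically need this; harmless for Ollama.
        headers["Authorization"] = f"Bearer {provider['api_key']}"

    request = urllib.request.Request(url, data=json.dumps(payload).encode(), headers=headers)
    with urllib.request.urlopen(request) as response:
        print(json.loads(response.read().decode()))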