74 lines
3.1 KiB
Python
74 lines
3.1 KiB
Python
"""llama_provider.py — Hermes inference router provider for llama.cpp."""
|
|
import logging, os, time
|
|
from dataclasses import dataclass
|
|
from typing import Optional
|
|
from bin.llama_client import ChatMessage, LlamaClient
|
|
|
|
logger = logging.getLogger("nexus.llama_provider")


def _env_flag(name: str, default: str) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    "true", "1", and "yes" (case-insensitive) are truthy; every other
    value — including the *default* used when the variable is unset —
    is falsy. Extracted so the parsing rule lives in exactly one place.
    """
    return os.environ.get(name, default).lower() in ("true", "1", "yes")


# llama.cpp server location and model selection (overridable via environment).
LLAMA_ENDPOINT = os.environ.get("LLAMA_ENDPOINT", "http://localhost:11435")
LLAMA_MODEL = os.environ.get("LLAMA_MODEL", "qwen2.5-7b")

# When true, route all inference to the local provider unconditionally.
LOCAL_ONLY = _env_flag("LOCAL_ONLY", "false")

# When true, fall back to local llama.cpp after an external provider fails.
FALLBACK_ON_FAILURE = _env_flag("LLAMA_FALLBACK", "true")
|
|
|
|
@dataclass
class ProviderResult:
    """Outcome of one inference call routed through a provider.

    On failure, ``text`` is empty and ``error`` carries the reason;
    on success ``error`` is None.
    """

    # Generated completion text (empty string when the call failed).
    text: str
    # Backend identifier; this module always produces "llama.cpp".
    provider: str = "llama.cpp"
    # Model name reported by the server, or the client's configured model.
    model: str = ""
    # Token count reported by the backend for this completion.
    tokens_used: int = 0
    # Wall time of the request in milliseconds.
    latency_ms: float = 0.0
    # Backend finish reason (e.g. stop/length), as returned by the client.
    finish_reason: str = ""
    # True for this provider: inference ran on the local machine.
    is_local: bool = True
    # Human-readable failure description; None on success.
    error: Optional[str] = None
|
|
|
|
class LlamaProvider:
    """Inference provider backed by a local llama.cpp server.

    Wraps ``LlamaClient`` with TTL-cached health checks, message
    validation, and the routing-policy helpers the Hermes router calls.
    """

    def __init__(self, endpoint=LLAMA_ENDPOINT, model=LLAMA_MODEL, local_only=LOCAL_ONLY,
                 health_ttl: float = 30.0):
        """
        Args:
            endpoint: Base URL of the llama.cpp server.
            model: Model identifier requested from the server.
            local_only: If true, this provider is always selected.
            health_ttl: Seconds a health-check verdict stays cached
                (previously a hard-coded 30; default preserves behavior).
        """
        self.client = LlamaClient(endpoint=endpoint, model=model)
        self.local_only = local_only
        self.endpoint = endpoint
        self.health_ttl = health_ttl
        self._last_health: Optional[bool] = None  # None = never checked
        self._last_check = 0.0  # monotonic timestamp of last real check

    def available(self) -> bool:
        """Return True when the server is healthy and a model is loaded.

        The verdict is cached for ``health_ttl`` seconds so every infer()
        call doesn't hit the health endpoint.
        """
        # time.monotonic(), not time.time(): wall-clock adjustments (NTP,
        # DST, manual changes) must not stretch or invalidate the TTL.
        now = time.monotonic()
        if self._last_health is not None and (now - self._last_check) < self.health_ttl:
            return self._last_health
        status = self.client.health_check()
        self._last_health = status.healthy and status.model_loaded
        self._last_check = now
        if not self._last_health:
            logger.warning("llama.cpp unhealthy: %s", status.error or "model not loaded")
        return self._last_health

    def infer(self, messages, max_tokens=512, temperature=0.7, model=None, **kwargs) -> ProviderResult:
        """Run a chat completion against the local server.

        Args:
            messages: Iterable of dicts with "role" and "content" keys;
                entries missing either key are silently dropped.
            max_tokens: Completion-length cap forwarded to the client.
            temperature: Sampling temperature forwarded to the client.
            model: Accepted for router-interface compatibility but not
                forwarded; the client uses its construction-time model.
                NOTE(review): confirm whether a per-call model override
                should be plumbed through to LlamaClient.chat.
            **kwargs: Ignored; present so router-level extras don't crash.

        Returns:
            ProviderResult with the generated text, or with ``error`` set
            and empty ``text`` on failure — this method never raises.
        """
        if not self.available():
            return ProviderResult(text="", error=f"llama.cpp at {self.endpoint} unavailable")
        chat_msgs = [
            ChatMessage(m["role"], m["content"])
            for m in messages
            if "role" in m and "content" in m
        ]
        if not chat_msgs:
            return ProviderResult(text="", error="No valid messages")
        start = time.monotonic()  # monotonic clock for a trustworthy latency
        try:
            resp = self.client.chat(chat_msgs, max_tokens=max_tokens, temperature=temperature)
            return ProviderResult(
                text=resp.text,
                provider="llama.cpp",
                model=resp.model or self.client.model,
                tokens_used=resp.tokens_used,
                latency_ms=(time.monotonic() - start) * 1000,
                finish_reason=resp.finish_reason,
                is_local=True,
            )
        except Exception as e:
            # Best-effort contract: callers receive an error result, not a raise.
            logger.error("llama.cpp failed: %s", e)
            return ProviderResult(text="", error=str(e))

    def should_use_local(self, external_failed: bool = False, explicit_local: bool = False) -> bool:
        """Routing policy: decide whether inference should run locally.

        Local wins when LOCAL_ONLY mode is on, when the caller explicitly
        asked for it, or — with fallback enabled — when an external
        provider just failed and the local server is actually reachable.
        """
        if self.local_only:
            return True
        if explicit_local:
            return True
        if external_failed and FALLBACK_ON_FAILURE:
            return self.available()
        return False

    def status(self) -> dict:
        """Return a JSON-serializable health snapshot (always uncached)."""
        h = self.client.health_check()
        return {
            "provider": "llama.cpp",
            "endpoint": self.endpoint,
            "healthy": h.healthy,
            "model_loaded": h.model_loaded,
            "model_name": h.model_name,
            "local_only": self.local_only,
        }

    def get_name(self) -> str:
        """Provider identifier used by the router."""
        return "llama.cpp"

    def get_priority(self) -> int:
        """Lower sorts first: 0 forces local-first when LOCAL_ONLY is set."""
        return 0 if self.local_only else 100
|