#!/usr/bin/env python3
"""llama_client.py — OpenAI-compatible client for llama.cpp HTTP API."""

import argparse
import json
import os
import sys
import time
import urllib.error
import urllib.request
from dataclasses import dataclass

DEFAULT_ENDPOINT = os.environ.get("LLAMA_ENDPOINT", "http://localhost:11435")
DEFAULT_MODEL = os.environ.get("LLAMA_MODEL", "qwen2.5-7b")
DEFAULT_MAX_TOKENS = int(os.environ.get("LLAMA_MAX_TOKENS", "512"))
DEFAULT_TEMPERATURE = float(os.environ.get("LLAMA_TEMPERATURE", "0.7"))


@dataclass
class ChatMessage:
    """A single chat turn; role is typically "system", "user", or "assistant"."""
    role: str
    content: str


@dataclass
class CompletionResponse:
    """Normalized result of a chat or completion call."""
    text: str
    tokens_used: int = 0
    latency_ms: float = 0.0
    model: str = ""
    finish_reason: str = ""


@dataclass
class HealthStatus:
    """Result of probing the server's /health endpoint."""
    healthy: bool
    endpoint: str
    model_loaded: bool = False
    model_name: str = ""
    error: str = ""

def _http_post(url, data, timeout=120):
    """POST a JSON body to url and decode the JSON response."""
    body = json.dumps(data).encode()
    req = urllib.request.Request(url, data=body,
                                 headers={"Content-Type": "application/json"},
                                 method="POST")
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())


def _http_get(url, timeout=10):
    """GET url and decode the JSON response."""
    req = urllib.request.Request(url, headers={"Accept": "application/json"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())

class LlamaClient:
    """Thin client for a llama.cpp server's OpenAI-compatible HTTP API."""

    def __init__(self, endpoint=DEFAULT_ENDPOINT, model=DEFAULT_MODEL):
        self.endpoint = endpoint.rstrip("/")
        self.model = model

    def health_check(self) -> HealthStatus:
        """Probe /health; never raises, failures are reported in the result."""
        try:
            data = _http_get(f"{self.endpoint}/health")
            return HealthStatus(healthy=True, endpoint=self.endpoint,
                                model_loaded=data.get("status") == "ok" or data.get("model_loaded", False),
                                model_name=data.get("model_path", self.model))
        except Exception as e:
            return HealthStatus(healthy=False, endpoint=self.endpoint, error=str(e))

    def is_healthy(self) -> bool:
        return self.health_check().healthy

    def list_models(self) -> list:
        """Return the model entries from /v1/models, or [] on any error."""
        try:
            data = _http_get(f"{self.endpoint}/v1/models")
            return data.get("data", [])
        except Exception:
            return []

    def chat(self, messages, max_tokens=DEFAULT_MAX_TOKENS, temperature=DEFAULT_TEMPERATURE, stream=False):
        """Blocking chat call against /v1/chat/completions.

        Use chat_stream() for streaming; passing stream=True here would make
        the server reply with SSE chunks that this method cannot parse.
        """
        payload = {"model": self.model,
                   "messages": [{"role": m.role, "content": m.content} for m in messages],
                   "max_tokens": max_tokens, "temperature": temperature, "stream": stream}
        start = time.time()
        data = _http_post(f"{self.endpoint}/v1/chat/completions", payload)
        latency = (time.time() - start) * 1000
        choice = data.get("choices", [{}])[0]
        msg = choice.get("message", {})
        usage = data.get("usage", {})
        return CompletionResponse(text=msg.get("content", ""),
                                  tokens_used=usage.get("total_tokens", 0), latency_ms=latency,
                                  model=data.get("model", self.model),
                                  finish_reason=choice.get("finish_reason", ""))
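
    # Example (illustrative; assumes a server is already listening on
    # DEFAULT_ENDPOINT, and skips error handling for brevity):
    #
    #   resp = LlamaClient().chat([ChatMessage("user", "Say hi")])
    #   print(resp.text, resp.latency_ms)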

    def chat_stream(self, messages, max_tokens=DEFAULT_MAX_TOKENS, temperature=DEFAULT_TEMPERATURE):
        """Stream a chat completion, yielding content deltas as they arrive.

        The server replies with Server-Sent Events: each payload line looks
        like "data: {json}", and "data: [DONE]" terminates the stream.
        """
        payload = {"model": self.model,
                   "messages": [{"role": m.role, "content": m.content} for m in messages],
                   "max_tokens": max_tokens, "temperature": temperature, "stream": True}
        req = urllib.request.Request(f"{self.endpoint}/v1/chat/completions",
                                     data=json.dumps(payload).encode(),
                                     headers={"Content-Type": "application/json"},
                                     method="POST")
        # Generous timeout, since long generations can stream for minutes.
        with urllib.request.urlopen(req, timeout=300) as resp:
            for line in resp:
                line = line.decode().strip()
                if not line.startswith("data: "):
                    continue
                chunk = line[len("data: "):]
                if chunk == "[DONE]":
                    break
                try:
                    data = json.loads(chunk)
                    content = data.get("choices", [{}])[0].get("delta", {}).get("content", "")
                    if content:
                        yield content
                except json.JSONDecodeError:
                    continue
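
    # Streaming example (illustrative sketch, same server assumption):
    #
    #   for piece in LlamaClient().chat_stream([ChatMessage("user", "Tell a joke")]):
    #       print(piece, end="", flush=True)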

    def simple_chat(self, prompt, system=None, max_tokens=DEFAULT_MAX_TOKENS):
        """One-shot convenience wrapper: optional system prompt, returns text only."""
        messages = []
        if system:
            messages.append(ChatMessage(role="system", content=system))
        messages.append(ChatMessage(role="user", content=prompt))
        return self.chat(messages, max_tokens=max_tokens).text

    def complete(self, prompt, max_tokens=DEFAULT_MAX_TOKENS, temperature=DEFAULT_TEMPERATURE):
        """Raw completion via llama.cpp's native /completion endpoint.

        Note the native field names: n_predict for the token budget, and
        content / tokens_predicted in the response.
        """
        payload = {"prompt": prompt, "n_predict": max_tokens, "temperature": temperature}
        start = time.time()
        data = _http_post(f"{self.endpoint}/completion", payload)
        return CompletionResponse(text=data.get("content", ""),
                                  tokens_used=data.get("tokens_predicted", 0),
                                  latency_ms=(time.time() - start) * 1000, model=self.model)

    def benchmark(self, prompt="Explain sovereignty in 3 sentences.", iterations=5, max_tokens=128):
        """Run the same chat prompt repeatedly and report latency statistics.

        tok_per_sec is derived from total_tokens (prompt + completion), so it
        overstates pure generation speed when the prompt is long.
        """
        latencies, token_counts = [], []
        for _ in range(iterations):
            resp = self.chat([ChatMessage(role="user", content=prompt)], max_tokens=max_tokens)
            latencies.append(resp.latency_ms)
            token_counts.append(resp.tokens_used)
        avg_lat = sum(latencies) / len(latencies)
        avg_tok = sum(token_counts) / len(token_counts)
        return {"iterations": iterations, "prompt": prompt,
                "avg_latency_ms": round(avg_lat, 1), "min_latency_ms": round(min(latencies), 1),
                "max_latency_ms": round(max(latencies), 1), "avg_tokens": round(avg_tok, 1),
                "tok_per_sec": round((avg_tok / avg_lat) * 1000 if avg_lat > 0 else 0, 1)}


def main():
    p = argparse.ArgumentParser(description="llama.cpp client CLI")
    p.add_argument("--url", default=DEFAULT_ENDPOINT)
    p.add_argument("--model", default=DEFAULT_MODEL)
    sub = p.add_subparsers(dest="cmd")
    sub.add_parser("health")
    sub.add_parser("models")
    cp = sub.add_parser("chat")
    cp.add_argument("prompt")
    cp.add_argument("--system")
    cp.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_TOKENS)
    cp.add_argument("--stream", action="store_true")
    bp = sub.add_parser("benchmark")
    bp.add_argument("--prompt", default="Explain sovereignty.")
    bp.add_argument("--iterations", type=int, default=5)
    bp.add_argument("--max-tokens", type=int, default=128)
    args = p.parse_args()
    client = LlamaClient(args.url, args.model)
    if args.cmd == "health":
        print(json.dumps(client.health_check().__dict__, indent=2))
        sys.exit(0 if client.is_healthy() else 1)
    elif args.cmd == "models":
        print(json.dumps(client.list_models(), indent=2))
    elif args.cmd == "chat":
        if args.stream:
            msgs = []
            if args.system:
                msgs.append(ChatMessage("system", args.system))
            msgs.append(ChatMessage("user", args.prompt))
            for chunk in client.chat_stream(msgs, max_tokens=args.max_tokens):
                print(chunk, end="", flush=True)
            print()
        else:
            print(client.simple_chat(args.prompt, system=args.system, max_tokens=args.max_tokens))
    elif args.cmd == "benchmark":
        print(json.dumps(client.benchmark(args.prompt, args.iterations, args.max_tokens), indent=2))
    else:
        p.print_help()


if __name__ == "__main__":
    main()
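
# Example invocations (assuming the file is saved as llama_client.py and a
# llama.cpp server is listening on the default endpoint):
#
#   python3 llama_client.py health
#   python3 llama_client.py chat "What is a context window?" --stream
#   python3 llama_client.py benchmark --iterations 3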